Polyhedral QP layer

We use DiffOpt to define a custom network layer which, given an input matrix y, computes its projection onto a polytope defined by a fixed number of inequalities: a_i^T x ≥ b_i. A neural network is created using Flux.jl and trained on the MNIST dataset, integrating this quadratic optimization layer.

The QP is solved in the forward pass, and its DiffOpt derivative is used in the backward pass expressed with ChainRulesCore.rrule.

This example is similar to the custom ReLU layer, except that the layer is parameterized by the hyperplanes (w,b) and not a simple stateless function. This also means that ChainRulesCore.rrule must return the derivatives of the output with respect to the layer parameters to allow for backpropagation.

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random

# Fix the RNG seed so the random hyperplanes and the reported metrics reproduce.
Random.seed!(42)
Random.TaskLocalRNG()

The Polytope representation and its derivative

# A polytope {x : ⟨w_i, x⟩ ≥ b_i, i = 1..N} described by N inequality
# constraints. Each `w[i]` is a dense coefficient matrix paired with the
# scalar right-hand side `b[i]`. Projection onto the set is done by calling
# the struct as a functor (defined further down in this file).
struct Polytope{N}
    w::NTuple{N, Matrix{Float64}}  # constraint coefficient matrices
    b::Vector{Float64}             # right-hand sides, length N
end

# Convenience constructor: build a `Polytope` from its coefficient matrices
# alone, drawing the right-hand sides `b` from a standard normal distribution.
function Polytope(w::NTuple{N}) where {N}
    return Polytope{N}(w, randn(N))
end
Main.Polytope

We define a "call" operation on the polytope, making it a so-called functor. Calling the polytope with a matrix `y` performs a Euclidean projection of this matrix onto the polytope.

# Project `y` onto the polytope by solving the QP
#     min ‖x - y‖²   s.t.   ⟨w_i, x⟩ ≥ b_i  for each hyperplane i.
# The `model` keyword defaults to a fresh DiffOpt-wrapped Ipopt model; the
# rrule below passes in its own model so it can query sensitivities after the
# solve. Note each `w[i]` must have the same size as `y` for `dot` to apply.
function (polytope::Polytope)(y::AbstractMatrix; model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer)))
    N, M = size(y)
    # drop variables/constraints left over from any previous solve of `model`
    empty!(model)
    set_silent(model)
    @variable(model, x[1:N, 1:M])
    @constraint(model, greater_than_cons[idx in 1:length(polytope.w)], dot(polytope.w[idx], x) ≥ polytope.b[idx])
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end

The @functor macro from Flux implements auxiliary functions for collecting the parameters of our custom layer and operating backpropagation.

# Register `Polytope` with Flux so its fields `w` and `b` are collected as
# trainable parameters by `Flux.params`.
Flux.@functor Polytope

Define the reverse differentiation rule, for the function we defined above. Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network. To learn the current layer (the polytope the layer contains), the gradient is computed with respect to the Polytope fields in a ChainRulesCore.Tangent type which is used to represent derivatives with respect to structs. For more details about backpropagation, visit Introduction, ChainRulesCore.jl.

# Reverse-mode rule for the projection layer. The forward pass solves the QP
# and keeps a handle on the JuMP model; the pullback uses DiffOpt's backward
# mode to map the incoming gradient `dl_dx` to gradients with respect to the
# input `y` and the layer parameters `(w, b)`, the latter wrapped in a
# `Tangent` so Flux can update the struct's fields.
function ChainRulesCore.rrule(polytope::Polytope, y::AbstractMatrix)
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        #  `dl_dy` is the derivative of `l` wrt `y`
        x = model[:x]
        # grad wrt input parameters
        dl_dy = zeros(size(dl_dx))
        # grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        # seed DiffOpt with the gradient of the loss wrt the optimal x
        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x, dl_dx)
        # propagate the sensitivities backward through the QP
        DiffOpt.backward(model)
        # `y` enters the objective ‖x - y‖² only through the linear term
        # -2⟨y, x⟩, hence the factor -2 on the coefficient of x in the
        # differentiated objective expression
        obj_expr = MOI.get(model, DiffOpt.BackwardOutObjective())
        dl_dy .= -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        for idx in eachindex(dl_dw)
            # for constraint ⟨w_i, x⟩ ≥ b_i: the coefficients of x in the
            # differentiated expression give ∂l/∂w_i; the constant term
            # carries -∂l/∂b_i
            cons_expr = MOI.get(model, DiffOpt.BackwardOutConstraint(), greater_than_cons[idx])
            dl_db[idx] = -JuMP.constant(cons_expr)
            dl_dw[idx] .= JuMP.coefficient.(cons_expr, x)
        end
        dself = ChainRulesCore.Tangent{typeof(polytope)}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end

Prepare data

N = 500  # number of training (and test) samples used in this example
# NOTE(review): `traintensor`/`trainlabels` is the older MLDatasets API —
# newer releases use `MNIST(split = :train)`; confirm the installed version.
imgs = MLDatasets.MNIST.traintensor(1:N)
labels = MLDatasets.MNIST.trainlabels(1:N);

Preprocessing

# Flatten each 28×28 image into a 784-vector: `train_X` is 784 × N.
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) ## stack all the images
# One-hot encode the digit labels (classes 0 through 9): `train_Y` is 10 × N.
train_Y = Flux.onehotbatch(labels, 0:9);

# Same preprocessing for the held-out test split.
test_imgs = MLDatasets.MNIST.testtensor(1:N)
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);

Define the Network

inner = 20  # width of the hidden layer

# Dense layer → projection onto a polytope with 3 hyperplanes → output layer
# → softmax. Each polytope matrix is `inner × N` because the projection's
# `dot(w, x)` requires `w` to match the layer input (`inner × batch`), and
# training below uses the full batch of N samples at once.
m = Flux.Chain(
    Flux.Dense(784, inner), ## 784 being image linear dimension (28 x 28)
    Polytope((randn(inner, N), randn(inner, N), randn(inner, N))),
    Flux.Dense(inner, 10), ## 10 being the number of outcomes (0 to 9)
    Flux.softmax,
)
Chain(
  Dense(784 => 20),                     # 15_700 parameters
  Main.Polytope{3}(([0.763010299225126 -0.050577659722164534 … -1.5614828121060058 -1.0523232233301327; 0.7035940274776313 1.8972368657899004 … -1.3192900970395856 0.36728994722441505; … ; 0.5442452868843884 0.7541078667628291 … 0.48217186403140977 0.4645704542013086; -0.4336598559509569 -0.4322222305100844 … -0.617200391117597 -0.4427862169261633], [-1.5011059075724487 0.30620952088484726 … 2.3622150518014866 -1.2920328899968427; -0.8079337289963097 0.4533553792350089 … 0.18578076815376657 -1.4903837019761388; … ; -0.5225729274127008 -1.038245444273301 … -0.8717363973503962 0.19469094062712794; -1.1152322409533746 -2.4749649738498327 … -0.9568724702338636 -0.3063801629184017], [0.20927421978950772 -1.3589300616187778 … -0.7752966290409944 1.39843195061478; 0.6926627348675881 1.3644700210731826 … -1.372403899975541 0.9189331042596978; … ; -0.10444953362793619 -0.7157846548069472 … 1.0951377409956915 -0.13342481569495498; -0.08821975129293722 1.3249256302211976 … 0.9941499246780054 -0.9419331496454915]), [-1.076776319092391, -0.10707476316332624, 0.1643198289502215]),  # 30_003 parameters
  Dense(20 => 10),                      # 210 parameters
  NNlib.softmax,
)                   # Total: 8 arrays, 45_913 parameters, 297.070 KiB.

Define the input data. The original data is repeated `epochs` times because `Flux.train!` only loops through the data set once.

epochs = 50  # number of passes over the training data

# Lazily repeat the (X, Y) pair so `Flux.train!` effectively runs `epochs` epochs.
dataset = repeated((train_X, train_Y), epochs);

Parameters for the network training

Training loss function and Flux optimizer.

# Cross-entropy between the network's softmax output and the one-hot targets.
custom_loss(x, y) = Flux.crossentropy(m(x), y)
opt = Flux.ADAM()
# callback used during training to monitor the loss on the training set
evalcb = () -> @show(custom_loss(train_X, train_Y))
#9 (generic function with 1 method)

Train to optimize network parameters

# Train; the callback is throttled to print the loss at most once per 5 seconds.
Flux.train!(custom_loss, Flux.params(m), dataset, opt, cb = Flux.throttle(evalcb, 5));
custom_loss(train_X, train_Y) = 2.364875262014027
custom_loss(train_X, train_Y) = 1.9400969781671051
custom_loss(train_X, train_Y) = 1.6126118853454194
custom_loss(train_X, train_Y) = 1.3517769406081526
custom_loss(train_X, train_Y) = 1.1452672731791906
custom_loss(train_X, train_Y) = 0.9798444097388876
custom_loss(train_X, train_Y) = 0.8476068375465406
custom_loss(train_X, train_Y) = 0.741311935644079
custom_loss(train_X, train_Y) = 0.6551726604656446
custom_loss(train_X, train_Y) = 0.5847416824798926

Although our custom implementation takes time, it is able to reach similar accuracy as the usual ReLU function implementation.

Average of correct guesses

# Fraction of samples for which the network's most probable class (via
# `onecold`) matches the one-hot encoded label.
function accuracy(x, y)
    predicted = Flux.onecold(m(x))
    actual = Flux.onecold(y)
    return Statistics.mean(predicted .== actual)
end;

Training accuracy

# Accuracy on the 500 training samples.
accuracy(train_X, train_Y)
0.894

Test accuracy

# Accuracy on the 500 held-out test samples.
accuracy(test_X, test_Y)
0.738

Note that the accuracy is low due to simplified training. It is possible to increase the number of samples `N`, the number of epochs `epochs`, and the connectivity `inner`.


This page was generated using Literate.jl.