Polyhedral QP layer

We use DiffOpt to define a custom network layer which, given an input matrix y, computes its projection onto a polytope defined by a fixed number of inequalities w_i^T x ≥ b_i. A neural network integrating this quadratic optimization layer is created using Flux.jl and trained on the MNIST dataset.
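Explicitly, each column y of the input is mapped to the optimal solution of the quadratic program below (this is only a restatement of the projection, using the same w and b as the Polytope struct defined next):

\min_{x} \; \lVert x - y \rVert_2^2 \quad \text{subject to} \quad w_i^\top x \ge b_i, \quad i = 1, \dots, N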

The QP is solved in the forward pass, and the derivatives computed by DiffOpt are used in the backward pass, which is expressed with a ChainRulesCore.rrule.

This example is similar to the custom ReLU layer example, except that the layer is parameterized by the hyperplanes (w, b) rather than being a simple stateless function. This also means that the pullback returned by ChainRulesCore.rrule must provide the derivatives with respect to the layer parameters, so that they can be updated by backpropagation.

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random

Random.seed!(42)
Random.TaskLocalRNG()

The Polytope representation and its derivative

struct Polytope{N}
    w::NTuple{N,Vector{Float64}}
    b::Vector{Float64}
end

Polytope(w::NTuple{N}) where {N} = Polytope{N}(w, randn(N))
Main.Polytope

We define a "call" operation on the polytope, making it a so-called functor. Calling the polytope with a matrix y operates an Euclidean projection of this matrix onto the polytope.

function (polytope::Polytope{N})(
    y::AbstractMatrix;
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer)),
) where {N}
    layer_size, batch_size = size(y)
    empty!(model)
    set_silent(model)
    @variable(model, x[1:layer_size, 1:batch_size])
    @constraint(
        model,
        greater_than_cons[idx in 1:N, sample in 1:batch_size],
        dot(polytope.w[idx], x[:, sample]) ≥ polytope.b[idx]
    )
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end
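As a quick usage sketch (not part of the original example; it assumes the Ipopt solver imported above is available), the layer can be called directly on a small random input:

# Hypothetical sanity check: project a random 3×2 matrix onto a polytope
# defined by two random halfspaces in R^3.
poly = Polytope((randn(3), randn(3)))
y = randn(3, 2)
x_proj = poly(y)        # each column of y is projected independently
size(x_proj) == size(y) # true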

The @functor macro from Flux implements auxiliary functions for collecting the parameters of our custom layer and for backpropagating through it.

Flux.@functor Polytope
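A small check (not in the original example) that the layer parameters are now visible to Flux:

# After `@functor`, Flux.params collects the layer's trainable arrays:
# the N hyperplane normals `w` and the offsets `b`.
ps = Flux.params(Polytope((randn(3), randn(3))))
length(ps)  # 3 arrays: w[1], w[2] and b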

Define the reverse differentiation rule for the function we defined above. Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network. To learn the current layer (the polytope it contains), the gradient with respect to the Polytope fields is returned in a ChainRulesCore.Tangent, the type used to represent derivatives with respect to structs. For more details about backpropagation, see the ChainRulesCore.jl documentation.

function ChainRulesCore.rrule(
    polytope::Polytope{N},
    y::AbstractMatrix,
) where {N}
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        layer_size, batch_size = size(dl_dx)
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        #  `dl_dy` is the derivative of `l` wrt `y`
        x = model[:x]
        # grad wrt input parameters
        dl_dy = zeros(size(dl_dx))
        # grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        # set sensitivities
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), x, dl_dx)
        # compute grad
        DiffOpt.reverse_differentiate!(model)
        # compute gradient wrt objective function parameter y
        obj_expr = MOI.get(model, DiffOpt.ReverseObjectiveFunction())
        dl_dy .= -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        for idx in 1:N, sample in 1:batch_size
            cons_expr = MOI.get(
                model,
                DiffOpt.ReverseConstraintFunction(),
                greater_than_cons[idx, sample],
            )
            dl_db[idx] -= JuMP.constant(cons_expr) / batch_size
            dl_dw[idx] .+=
                JuMP.coefficient.(cons_expr, x[:, sample]) / batch_size
        end
        dself = ChainRulesCore.Tangent{Polytope{N}}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end
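As a sketch of how this rule is picked up (not part of the original example; it assumes Zygote.jl, the AD backend Flux depends on, is available), one can differentiate a scalar function of the projection directly:

import Zygote
poly = Polytope((randn(3), randn(3)))
y0 = randn(3, 2)
# Zygote dispatches to the rrule above; grads[1] carries the `w` and `b`
# cotangents (the Tangent), grads[2] has the same shape as y0.
grads = Zygote.gradient((p, y) -> sum(p(y)), poly, y0)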

Define the Network

layer_size = 20
m = Flux.Chain(
    Flux.Dense(784, layer_size), # 784 being image linear dimension (28 x 28)
    Polytope((randn(layer_size), randn(layer_size), randn(layer_size))),
    Flux.Dense(layer_size, 10), # 10 being the number of outcomes (0 to 9)
    Flux.softmax,
)
Chain(
  Dense(784 => 20),                     # 15_700 parameters
  Main.Polytope{3}(([0.763010299225126, 0.7035940274776313, -0.14599734645599982, -0.5536654908199037, 1.2869038273869011, 0.008341276474938202, -0.21620568547348124, 0.5804257435407556, -1.3392265917712003, 1.1478973141695823, -1.9733824220360674, 1.028013820942376, 0.3330700309002022, 0.3617600425138335, -1.453462820162369, -1.057330339314276, -1.3587775969043114, -1.1021467626767054, -0.5235197262169707, -0.11786930391333741], [0.351668972661715, 1.041263914306683, 0.46422300211229406, 0.04826873156033191, -0.25596725425831507, -0.22335670215229697, 1.0415674540673705, -0.3324595271156483, -0.2925751283411379, 0.22833547427474538, 0.8549714488784929, 0.013741809970875022, 1.6106151302442777, -0.1815465481057267, 1.4703466695005445, -1.4342494427631365, 2.2573510919898294, 0.5003066411786888, -0.6705635926644442, 0.7137770813874245], [-0.9962722884240374, 0.5735892597413242, 0.16779617569974092, 0.10928346396793628, -0.5314405685385368, 0.10145521150756781, 1.7523306866672552, -0.4759159998638638, -0.04913953026998466, 1.0468091878610841, 0.28067361040960415, -1.037581476255182, -0.4305762270027227, -0.559457160703659, -0.013166541964542387, 0.5942186617800541, -0.5526933729536513, -0.42696959994340017, 0.9373848400003064, 0.11406394788917756]), [0.4489732065383038, -0.6754953046551224, 1.4964212374772967]),  # 63 parameters
  Dense(20 => 10),                      # 210 parameters
  NNlib.softmax,
)                   # Total: 8 arrays, 15_973 parameters, 63.164 KiB.

Prepare data

M = 500 # batch size
# Preprocessing train data
train_data = MLDatasets.MNIST(; split = :train)
imgs = train_data.features[:, :, 1:M]
labels = train_data.targets[1:M];
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), M)) # stack images
train_Y = Flux.onehotbatch(labels, 0:9);
# Preprocessing test data
test_data = MLDatasets.MNIST(; split = :test)
test_imgs = test_data.features[:, :, 1:M]
test_labels = test_data.targets[1:M]
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), M))
test_Y = Flux.onehotbatch(test_labels, 0:9);

Define the input data. The original data is repeated epochs times because Flux.train! only loops through the data set once.

epochs = 50
dataset = repeated((train_X, train_Y), epochs);

Network training

Training loss function and Flux optimizer

custom_loss(x, y) = Flux.crossentropy(m(x), y)
opt = Flux.ADAM()
evalcb = () -> @show(custom_loss(train_X, train_Y))
#9 (generic function with 1 method)

Train to optimize network parameters

@time Flux.train!(
    custom_loss,
    Flux.params(m),
    dataset,
    opt,
    cb = Flux.throttle(evalcb, 5),
);
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(20 => 10)     # 210 parameters
│   summary(x) = "20×500 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/FWgS0/src/layers/stateless.jl:50
custom_loss(train_X, train_Y) = 2.3216214f0
custom_loss(train_X, train_Y) = 1.8995347f0
custom_loss(train_X, train_Y) = 1.5656878f0
custom_loss(train_X, train_Y) = 1.3102238f0
custom_loss(train_X, train_Y) = 1.1158031f0
custom_loss(train_X, train_Y) = 0.9647578f0
custom_loss(train_X, train_Y) = 0.84461033f0
custom_loss(train_X, train_Y) = 0.74726844f0
custom_loss(train_X, train_Y) = 0.6669284f0
custom_loss(train_X, train_Y) = 0.5996534f0
103.313059 seconds (55.95 M allocations: 5.874 GiB, 0.98% gc time, 0.96% compilation time)

Although our custom implementation takes time, it is able to reach accuracy similar to that of the usual ReLU activation.
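For reference (a sketch only, not run in the original example), the analogous architecture with a standard ReLU activation in place of the projection layer would be:

m_relu = Flux.Chain(
    Flux.Dense(784, layer_size, Flux.relu),
    Flux.Dense(layer_size, 10),
    Flux.softmax,
)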

Accuracy results

Average of correct guesses

accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y));

Training accuracy

accuracy(train_X, train_Y)
0.9

Test accuracy

accuracy(test_X, test_Y)
0.744

Note that the accuracy is low due to the simplified training setup. It is possible to increase the number of samples M, the number of epochs, and the layer size layer_size.


This page was generated using Literate.jl.