Polyhedral QP layer

We use DiffOpt to define a custom network layer which, given an input matrix y, computes its projection onto a polytope defined by a fixed number of inequalities: w_i^T x ≥ b_i. A neural network is created using Flux.jl and trained on the MNIST dataset, integrating this quadratic optimization layer.

The QP is solved in the forward pass, and its DiffOpt derivative is used in the backward pass expressed with ChainRulesCore.rrule.

This example is similar to the custom ReLU layer, except that the layer is parameterized by the hyperplanes (w,b) and not a simple stateless function. This also means that ChainRulesCore.rrule must return the derivatives of the output with respect to the layer parameters to allow for backpropagation.

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random

# Seed the global RNG so that the `randn` draws used below are reproducible.
Random.seed!(42)
Random.TaskLocalRNG()

The Polytope representation and its derivative

"""
    Polytope{N}

Set of points `x` satisfying the `N` inequalities `dot(w[i], x) ≥ b[i]`.

# Fields
- `w`: tuple holding the `N` normal vectors.
- `b`: vector holding the right-hand sides.
"""
struct Polytope{N}
    w::NTuple{N, Vector{Float64}}
    b::Vector{Float64}
end

"""
    Polytope(w::NTuple{N})

Build a polytope from the normal vectors `w`; the `N` right-hand sides are
drawn with `randn`.
"""
function Polytope(w::NTuple{N}) where {N}
    return Polytope{N}(w, randn(N))
end
Main.Polytope

We define a "call" operation on the polytope, making it a so-called functor. Calling the polytope with a matrix y performs a Euclidean projection of this matrix onto the polytope.

"""
    (polytope::Polytope{N})(y::AbstractMatrix; model = ...)

Solve the QP that minimizes `dot(x - y, x - y)` subject to every column of `x`
satisfying the `N` inequalities of `polytope`, and return the optimal `x` as a
matrix of the same size as `y`.

# Keywords
- `model`: JuMP model used for the solve. It is emptied first; afterwards the
  variables and constraints are registered in it as `x` and `greater_than_cons`.
"""
function (polytope::Polytope{N})(
    y::AbstractMatrix;
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer)),
) where {N}
    n_features, n_samples = size(y)
    # Clear whatever the supplied model already contains and mute solver output.
    empty!(model)
    set_silent(model)
    # One column of decision variables per column of `y`.
    @variable(model, x[1:n_features, 1:n_samples])
    # Column `j` of `x` must satisfy inequality `k` of the polytope.
    @constraint(
        model,
        greater_than_cons[k in 1:N, j in 1:n_samples],
        dot(polytope.w[k], x[:, j]) ≥ polytope.b[k]
    )
    # Squared Euclidean distance between `x` and `y`.
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end

The @functor macro from Flux implements auxiliary functions for collecting the parameters of our custom layer and operating backpropagation.

# Let Flux collect the fields of `Polytope` as parameters of the layer.
Flux.@functor Polytope

Define the reverse differentiation rule for the function we defined above. Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network. To learn the current layer (the polytope the layer contains), the gradient is computed with respect to the Polytope fields and returned in a ChainRulesCore.Tangent, the type used to represent derivatives with respect to structs. For more details about backpropagation, see the Introduction and the ChainRulesCore.jl documentation.

"""
    ChainRulesCore.rrule(polytope::Polytope{N}, y::AbstractMatrix)

Reverse-mode rule for the projection `polytope(y)`.

Returns the projected matrix together with a pullback. Given the cotangent
`dl_dx` of the output, the pullback returns `(dself, dl_dy)`, where `dself` is a
`ChainRulesCore.Tangent{Polytope{N}}` with fields `w` and `b`, and `dl_dy` is
the cotangent with respect to `y`.
"""
function ChainRulesCore.rrule(polytope::Polytope{N}, y::AbstractMatrix) where {N}
    # The pullback queries this model after the solve, so create it here and
    # hand it to the forward call instead of using the default one.
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        # Materialize the cotangent first: it may arrive as a lazy thunk and
        # must be unwrapped before its size is queried or it is broadcast over.
        dx = ChainRulesCore.unthunk(dl_dx)
        batch_size = size(dx, 2)
        x = model[:x]
        # grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        # set sensitivities
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), x, dx)
        # compute grad
        DiffOpt.reverse_differentiate!(model)
        # grad wrt the input `y`: in `dot(x - y, x - y)` the coefficient of `x`
        # in the linear term is `-2y`, hence the factor -2.
        obj_expr = MOI.get(model, DiffOpt.ReverseObjectiveFunction())
        dl_dy = -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        # Accumulate the per-sample constraint sensitivities, divided by the
        # number of samples in the batch.
        for idx in 1:N, sample in 1:batch_size
            cons_expr = MOI.get(
                model,
                DiffOpt.ReverseConstraintFunction(),
                greater_than_cons[idx, sample],
            )
            dl_db[idx] -= JuMP.constant(cons_expr) / batch_size
            dl_dw[idx] .+= JuMP.coefficient.(cons_expr, x[:, sample]) / batch_size
        end
        dself = ChainRulesCore.Tangent{Polytope{N}}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end

Define the Network

# Number of features entering (and leaving) the projection layer.
layer_size = 20
m = Flux.Chain(
    # 784 = 28 x 28 pixels of a flattened image
    Flux.Dense(784, layer_size),
    # projection onto a polytope described by three random hyperplanes
    Polytope((
        randn(layer_size),
        randn(layer_size),
        randn(layer_size),
    )),
    # one output per digit (0 to 9)
    Flux.Dense(layer_size, 10),
    Flux.softmax,
)
Chain(
  Dense(784 => 20),                     # 15_700 parameters
  Main.Polytope{3}(([0.763010299225126, 0.7035940274776313, -0.14599734645599982, -0.5536654908199037, 1.2869038273869011, 0.008341276474938202, -0.21620568547348124, 0.5804257435407556, -1.3392265917712003, 1.1478973141695823, -1.9733824220360674, 1.028013820942376, 0.3330700309002022, 0.3617600425138335, -1.453462820162369, -1.057330339314276, -1.3587775969043114, -1.1021467626767054, -0.5235197262169707, -0.11786930391333741], [0.351668972661715, 1.041263914306683, 0.46422300211229406, 0.04826873156033191, -0.25596725425831507, -0.22335670215229697, 1.0415674540673705, -0.3324595271156483, -0.2925751283411379, 0.22833547427474538, 0.8549714488784929, 0.013741809970875022, 1.6106151302442777, -0.1815465481057267, 1.4703466695005445, -1.4342494427631365, 2.2573510919898294, 0.5003066411786888, -0.6705635926644442, 0.7137770813874245], [-0.9962722884240374, 0.5735892597413242, 0.16779617569974092, 0.10928346396793628, -0.5314405685385368, 0.10145521150756781, 1.7523306866672552, -0.4759159998638638, -0.04913953026998466, 1.0468091878610841, 0.28067361040960415, -1.037581476255182, -0.4305762270027227, -0.559457160703659, -0.013166541964542387, 0.5942186617800541, -0.5526933729536513, -0.42696959994340017, 0.9373848400003064, 0.11406394788917756]), [0.4489732065383038, -0.6754953046551224, 1.4964212374772967]),  # 63 parameters
  Dense(20 => 10),                      # 210 parameters
  NNlib.softmax,
)                   # Total: 8 arrays, 15_973 parameters, 63.164 KiB.

Prepare data

M = 500 # batch size
# Preprocessing train data
# `MNIST(split = ...)` replaces the deprecated `MNIST.traintensor` /
# `MNIST.trainlabels` accessors; indexing returns `features` and `targets`.
train_data = MLDatasets.MNIST(; split = :train)[1:M]
imgs = train_data.features
labels = train_data.targets;
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), M)) # stack images
train_Y = Flux.onehotbatch(labels, 0:9);
# Preprocessing test data
test_data = MLDatasets.MNIST(; split = :test)[1:M]
test_imgs = test_data.features
test_labels = test_data.targets
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), M))
test_Y = Flux.onehotbatch(test_labels, 0:9);

Define input data. The original data is repeated `epochs` times because Flux.train! only loops through the data set once.

# Number of times the training batch is presented to the optimizer.
epochs = 50
# Iterator yielding the same `(train_X, train_Y)` pair `epochs` times.
dataset = Iterators.repeated((train_X, train_Y), epochs);

Network training

training loss function, Flux optimizer

# Cross-entropy between the network output `m(x)` and the targets `y`.
function custom_loss(x, y)
    return Flux.crossentropy(m(x), y)
end
opt = Flux.ADAM()
# Callback printing the loss evaluated on the training data.
evalcb = () -> @show(custom_loss(train_X, train_Y))
#9 (generic function with 1 method)

Train to optimize network parameters

# Run the training loop; the loss callback is throttled with `Flux.throttle(evalcb, 5)`.
@time Flux.train!(
    custom_loss,
    Flux.params(m),
    dataset,
    opt;
    cb = Flux.throttle(evalcb, 5),
);
custom_loss(train_X, train_Y) = 2.3216230024496745
custom_loss(train_X, train_Y) = 1.8995358083695622
custom_loss(train_X, train_Y) = 1.5656887963878012
custom_loss(train_X, train_Y) = 1.3102242573150837
custom_loss(train_X, train_Y) = 1.1158038743206777
custom_loss(train_X, train_Y) = 0.964758416377316
custom_loss(train_X, train_Y) = 0.866590623825587
custom_loss(train_X, train_Y) = 0.7652301921815144
custom_loss(train_X, train_Y) = 0.6818464248641992
custom_loss(train_X, train_Y) = 0.6122045971774231
custom_loss(train_X, train_Y) = 0.5534808526578696
133.202730 seconds (126.72 M allocations: 7.221 GiB, 1.39% gc time, 1.07% compilation time)

Although our custom implementation takes time, it is able to reach an accuracy similar to that of the usual ReLU function implementation.

Accuracy results

Average of correct guesses

# Fraction of columns of `x` whose predicted class matches the class encoded in `y`.
function accuracy(x, y)
    predicted = Flux.onecold(m(x))
    expected = Flux.onecold(y)
    return Statistics.mean(predicted .== expected)
end;

Training accuracy

accuracy(train_X, train_Y)
0.9

Test accuracy

accuracy(test_X, test_Y)
0.744

Note that the accuracy is low due to simplified training. It is possible to increase the number of samples `M`, the number of epochs `epochs` and the connectivity `layer_size`.


This page was generated using Literate.jl.