Polyhedral QP layer
We use DiffOpt to define a custom network layer which, given an input matrix y, computes the projection of each of its columns onto a polytope defined by a fixed number N of inequalities: w_i^T x ≥ b_i. A neural network is created using Flux.jl and trained on the MNIST dataset, integrating this quadratic optimization layer.
The QP is solved in the forward pass, and its DiffOpt derivative is used in the backward pass, expressed with ChainRulesCore.rrule.
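For reference, the projection computed by the layer for a single column $y$ of the input can be written as the quadratic program below. This is a sketch in notation assumed here: $w_i$ is the normal vector of the $i$-th inequality and $b_i$ its offset, matching the fields `w` and `b` used in the code, and $n$ is the layer size.

```latex
\begin{aligned}
\operatorname{proj}(y) \;=\; \arg\min_{x \in \mathbb{R}^n} \quad & \lVert x - y \rVert_2^2 \\
\text{s.t.} \quad & w_i^\top x \ge b_i, \qquad i = 1, \dots, N .
\end{aligned}
```

The objective is strictly convex, so the projection is unique whenever the polytope is non-empty.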
This example is similar to the custom ReLU layer, except that the layer is parameterized by the hyperplanes (w, b) rather than being a simple stateless function. This also means that ChainRulesCore.rrule must return the derivatives of the output with respect to the layer parameters to allow for backpropagation.
using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random
Random.seed!(42)
Random.TaskLocalRNG()
The Polytope representation and its derivative
struct Polytope{N}
    w::NTuple{N,Vector{Float64}}
    b::Vector{Float64}
end
Polytope(w::NTuple{N}) where {N} = Polytope{N}(w, randn(N))
Main.Polytope
We define a "call" operation on the polytope, making it a so-called functor. Calling the polytope with a matrix y computes the Euclidean projection of each column of this matrix onto the polytope.
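To make the functor idea concrete without a solver, here is a minimal sketch for the special case of a single inequality, where the projection has a closed form. The `HalfSpace` type is hypothetical and introduced only for illustration; it is not part of DiffOpt or of this tutorial's layer.

```julia
using LinearAlgebra

# Hypothetical single-inequality analogue of the polytope layer:
# the set {x : dot(w, x) >= b}.
struct HalfSpace
    w::Vector{Float64}
    b::Float64
end

# Defining a method on the instance makes the struct callable (a functor).
# Each column of `y` is projected independently.
function (h::HalfSpace)(y::AbstractMatrix)
    x = float.(y)  # fresh floating-point copy, `y` is left untouched
    for col in eachcol(x)
        violation = h.b - dot(h.w, col)
        if violation > 0
            # move the point along `w` until the inequality is tight
            col .+= (violation / dot(h.w, h.w)) .* h.w
        end
    end
    return x
end

h = HalfSpace([1.0, 0.0], 1.0)
y = [0.0 2.0; 3.0 4.0]
x = h(y)
```

The first column violates the inequality and is moved onto the boundary; the second column already satisfies it and is returned unchanged. With several inequalities no such closed form is available in general, which is why the layer solves a QP.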
function (polytope::Polytope{N})(
    y::AbstractMatrix;
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer)),
) where {N}
    layer_size, batch_size = size(y)
    empty!(model)
    set_silent(model)
    @variable(model, x[1:layer_size, 1:batch_size])
    @constraint(
        model,
        greater_than_cons[idx in 1:N, sample in 1:batch_size],
        dot(polytope.w[idx], x[:, sample]) ≥ polytope.b[idx]
    )
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end
The @functor macro from Flux implements auxiliary functions for collecting the parameters of our custom layer and performing backpropagation.
Flux.@functor Polytope
Define the reverse differentiation rule for the function we defined above. Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network. To learn the current layer (the polytope the layer contains), the gradient is computed with respect to the Polytope fields and returned in a ChainRulesCore.Tangent type, which is used to represent derivatives with respect to structs. For more details about backpropagation, see the Introduction of the ChainRulesCore.jl documentation.
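The shape of such a rule can be seen on a much smaller example. The sketch below uses a hypothetical `Scale` layer (not part of this tutorial) whose only parameter is a vector `s`; the pullback returns a `Tangent` for the struct first, then the derivative with respect to the input.

```julia
import ChainRulesCore

# Hypothetical parameterized layer: elementwise scaling by `s`.
struct Scale
    s::Vector{Float64}
end
(layer::Scale)(x::AbstractVector) = layer.s .* x

function ChainRulesCore.rrule(layer::Scale, x::AbstractVector)
    out = layer(x)
    function scale_pullback(dl_dout)
        dl_dout = ChainRulesCore.unthunk(dl_dout)
        # derivative wrt the struct field `s`, wrapped in a Tangent
        dself = ChainRulesCore.Tangent{Scale}(; s = dl_dout .* x)
        # derivative wrt the input `x`
        dl_dx = dl_dout .* layer.s
        return (dself, dl_dx)
    end
    return out, scale_pullback
end

layer = Scale([2.0, 3.0])
out, back = ChainRulesCore.rrule(layer, [1.0, 1.0])
dself, dl_dx = back([1.0, 1.0])
```

The first element of the returned tuple corresponds to the callable itself, which is how the gradient with respect to the layer's own fields reaches the optimizer.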
function ChainRulesCore.rrule(
    polytope::Polytope{N},
    y::AbstractMatrix,
) where {N}
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        # materialize the incoming cotangent before querying its size
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        _, batch_size = size(dl_dx)
        # `dl_dy` is the derivative of `l` wrt `y`
        x = model[:x]
        # grad wrt input parameters
        dl_dy = zeros(size(dl_dx))
        # grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        # set sensitivities
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), x, dl_dx)
        # compute grad
        DiffOpt.reverse_differentiate!(model)
        # compute gradient wrt objective function parameter y
        obj_expr = MOI.get(model, DiffOpt.ReverseObjectiveFunction())
        dl_dy .= -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        for idx in 1:N, sample in 1:batch_size
            cons_expr = MOI.get(
                model,
                DiffOpt.ReverseConstraintFunction(),
                greater_than_cons[idx, sample],
            )
            dl_db[idx] -= JuMP.constant(cons_expr) / batch_size
            dl_dw[idx] .+=
                JuMP.coefficient.(cons_expr, x[:, sample]) / batch_size
        end
        dself = ChainRulesCore.Tangent{Polytope{N}}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end
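The factor of -2 and the sign flip on `dl_db` in the pullback can be traced back to how y, w and b enter the problem. The derivation below is a sketch for a single sample. It assumes that DiffOpt reports sensitivities with respect to the linear objective coefficients and with respect to the coefficients and constant of the constraint function written as $w_i^\top x + d_i \ge 0$. The symbols $c$ and $d_i$ are introduced here for illustration only.

```latex
\begin{aligned}
\lVert x - y \rVert_2^2 &= x^\top x - 2\, y^\top x + y^\top y
  &&\Longrightarrow\quad c = -2y, \qquad
  \frac{\partial l}{\partial y} = -2\, \frac{\partial l}{\partial c}, \\[4pt]
w_i^\top x \ge b_i \;&\Longleftrightarrow\; w_i^\top x + d_i \ge 0, \quad d_i = -b_i
  &&\Longrightarrow\quad
  \frac{\partial l}{\partial b_i} = -\frac{\partial l}{\partial d_i} .
\end{aligned}
```

The term $y^\top y$ does not depend on $x$ and therefore does not influence the minimizer. The division by `batch_size` in the loop averages the per-sample contributions, since w and b are shared across all samples.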
Define the Network
layer_size = 20
m = Flux.Chain(
    Flux.Dense(784, layer_size), # 784 being image linear dimension (28 x 28)
    Polytope((randn(layer_size), randn(layer_size), randn(layer_size))),
    Flux.Dense(layer_size, 10), # 10 being the number of outcomes (0 to 9)
    Flux.softmax,
)
Chain(
Dense(784 => 20), # 15_700 parameters
Main.Polytope{3}(([0.763010299225126, 0.7035940274776313, -0.14599734645599982, -0.5536654908199037, 1.2869038273869011, 0.008341276474938202, -0.21620568547348124, 0.5804257435407556, -1.3392265917712003, 1.1478973141695823, -1.9733824220360674, 1.028013820942376, 0.3330700309002022, 0.3617600425138335, -1.453462820162369, -1.057330339314276, -1.3587775969043114, -1.1021467626767054, -0.5235197262169707, -0.11786930391333741], [0.351668972661715, 1.041263914306683, 0.46422300211229406, 0.04826873156033191, -0.25596725425831507, -0.22335670215229697, 1.0415674540673705, -0.3324595271156483, -0.2925751283411379, 0.22833547427474538, 0.8549714488784929, 0.013741809970875022, 1.6106151302442777, -0.1815465481057267, 1.4703466695005445, -1.4342494427631365, 2.2573510919898294, 0.5003066411786888, -0.6705635926644442, 0.7137770813874245], [-0.9962722884240374, 0.5735892597413242, 0.16779617569974092, 0.10928346396793628, -0.5314405685385368, 0.10145521150756781, 1.7523306866672552, -0.4759159998638638, -0.04913953026998466, 1.0468091878610841, 0.28067361040960415, -1.037581476255182, -0.4305762270027227, -0.559457160703659, -0.013166541964542387, 0.5942186617800541, -0.5526933729536513, -0.42696959994340017, 0.9373848400003064, 0.11406394788917756]), [0.4489732065383038, -0.6754953046551224, 1.4964212374772967]), # 63 parameters
Dense(20 => 10), # 210 parameters
NNlib.softmax,
) # Total: 8 arrays, 15_973 parameters, 63.031 KiB.
Prepare data
M = 500 # batch size
# Preprocessing train data
imgs = MLDatasets.MNIST.traintensor(1:M)
labels = MLDatasets.MNIST.trainlabels(1:M);
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), M)) # stack images
train_Y = Flux.onehotbatch(labels, 0:9);
# Preprocessing test data
test_imgs = MLDatasets.MNIST.testtensor(1:M)
test_labels = MLDatasets.MNIST.testlabels(1:M)
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), M))
test_Y = Flux.onehotbatch(test_labels, 0:9);
┌ Warning: MNIST.traintensor() is deprecated, use `MNIST(split=:train).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:157
┌ Warning: MNIST.trainlabels() is deprecated, use `MNIST(split=:train).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:173
┌ Warning: MNIST.testtensor() is deprecated, use `MNIST(split=:test).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:165
┌ Warning: MNIST.testlabels() is deprecated, use `MNIST(split=:test).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:180
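The reshape calls above flatten each 28 × 28 image into one column of length 784. The sketch below shows the same operation on random stand-in data (the array `fake_imgs` is hypothetical and replaces the MNIST tensor), so it runs without downloading the dataset.

```julia
# Stand-in for a 28 x 28 x M image tensor; values are random, not MNIST.
n_samples = 5
fake_imgs = rand(Float32, 28, 28, n_samples)

# Same reshape as in the preprocessing step: one column per image.
X = reshape(fake_imgs, size(fake_imgs, 1) * size(fake_imgs, 2), n_samples)

# Because Julia arrays are column-major, column k is the flattened k-th image.
first_matches = X[:, 1] == vec(fake_imgs[:, :, 1])
```

The resulting matrix has 784 rows, which is why the first Dense layer of the network takes 784 inputs.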
Define input data
The original data is repeated epochs times because Flux.train! only loops through the data set once.
epochs = 50
dataset = repeated((train_X, train_Y), epochs);
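As a small illustration of what `repeated` produces, the sketch below uses a toy tuple in place of `(train_X, train_Y)`. The iterator yields the very same batch object on every pass and does not copy the data.

```julia
using Base.Iterators: repeated

batch = ([1 2; 3 4], [0, 1])   # stand-in for (train_X, train_Y)
toy_dataset = repeated(batch, 3)   # 3 plays the role of `epochs`

n_passes = length(collect(toy_dataset))
# every element is identical (===) to the original tuple
all_same = all(b -> b === batch, toy_dataset)
```

Each element of the iterator is consumed as one training step, so with a single full-size batch the number of repetitions equals the number of passes over the data.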
Network training
Define the training loss function and the Flux optimizer
custom_loss(x, y) = Flux.crossentropy(m(x), y)
opt = Flux.ADAM()
evalcb = () -> @show(custom_loss(train_X, train_Y))
#8 (generic function with 1 method)
Train to optimize network parameters
@time Flux.train!(
    custom_loss,
    Flux.params(m),
    dataset,
    opt,
    cb = Flux.throttle(evalcb, 5),
);
┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(20 => 10) # 210 parameters
│ summary(x) = "20×500 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/hiqg1/src/layers/stateless.jl:60
custom_loss(train_X, train_Y) = 2.3216217f0
custom_loss(train_X, train_Y) = 1.7559088f0
custom_loss(train_X, train_Y) = 1.3558202f0
custom_loss(train_X, train_Y) = 1.0826f0
custom_loss(train_X, train_Y) = 0.88952696f0
custom_loss(train_X, train_Y) = 0.7472685f0
custom_loss(train_X, train_Y) = 0.6386277f0
custom_loss(train_X, train_Y) = 0.55348027f0
64.293757 seconds (35.89 M allocations: 4.149 GiB, 1.53% gc time, 1.42% compilation time)
Although the custom layer is slow to train (about 64 seconds for this run), it reaches an accuracy similar to that of the usual ReLU implementation.
Accuracy results
Average of correct guesses
accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y));
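The accuracy computation compares, column by column, the index of the largest predicted score with the index of the true class. The sketch below reproduces that logic with plain `argmax` on hand-written toy matrices instead of `Flux.onecold`, so the names `scores` and `onehot` are illustrative stand-ins for `m(x)` and `y`.

```julia
using Statistics

# Columns are samples, rows are classes.
scores = [0.1 0.7 0.2;
          0.8 0.2 0.3;
          0.1 0.1 0.5]
onehot = Bool[0 1 1;
              1 0 0;
              0 0 0]

predicted = map(argmax, eachcol(scores))   # class index with the highest score
actual = map(argmax, eachcol(onehot))      # class index of the true label
acc = mean(predicted .== actual)
```

Here the first two samples are classified correctly and the third is not, so the accuracy is two thirds.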
Training accuracy
accuracy(train_X, train_Y)
0.9
Test accuracy
accuracy(test_X, test_Y)
0.744
Note that the accuracy is low due to simplified training. It is possible to increase the number of samples M, the number of epochs epochs, and the width of the hidden layer layer_size.
This page was generated using Literate.jl.