Custom ReLU layer

We demonstrate how DiffOpt can be used to implement a simple neural network unit, the ReLU layer, as a differentiable optimization problem. A neural network containing this layer is built with Flux.jl and trained on the MNIST dataset.

This tutorial uses the following packages

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
import Base.Iterators: repeated
using LinearAlgebra

The ReLU and its derivative

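Define a ReLU through a quadratic optimization problem: for an input `y`, minimize `x'x - 2y'x` subject to `x >= 0`. Up to the constant `y'y` this is the squared distance `‖x - y‖²`, so the minimizer is `x = max.(y, 0)`. The problem is solved with Ipopt wrapped in a DiffOpt differentiable optimizer, and the function returns the solution of the problem.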

"""
    matrix_relu(y::Matrix; model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer)))

Evaluate an element-wise ReLU of `y` by solving the optimization problem
`min x'x - 2y'x` subject to `x >= 0`, and return the optimal `x` as a matrix
with the same size as `y`.

# Keywords
- `model`: the JuMP model used to build and solve the problem. It is emptied
  before use, and the decision variables are registered under the name `x`,
  so callers can retrieve them afterwards with `model[:x]`.
"""
function matrix_relu(
    y::Matrix;
    model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer)),
)
    n_rows, n_cols = size(y)
    # Start from a clean, quiet model so a caller-supplied model can be reused.
    empty!(model)
    set_silent(model)
    @variable(model, x[1:n_rows, 1:n_cols] >= 0)
    # Flatten both matrices so the objective can be written with inner products.
    x_flat = vec(x)
    y_flat = vec(y)
    @objective(model, Min, x_flat' * x_flat - 2 * (y_flat' * x_flat))
    optimize!(model)
    return value.(x)
end
matrix_relu (generic function with 1 method)

Define the reverse differentiation rule, for the function we defined above.

# Reverse-mode rule for `matrix_relu`.
#
# The forward pass solves the optimization problem on a fresh differentiable model.
# The returned pullback reuses that solved model to map the incoming sensitivity
# `dl_dx` (derivative of some downstream value `l`, e.g. the loss, with respect to
# the layer output `x`) to `dl_dy`, the derivative of `l` with respect to the input `y`.
function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where {T}
    model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
    primal = matrix_relu(y; model = model)
    function pullback_matrix_relu(dl_dx)
        # decision variables created by `matrix_relu`, flattened to a vector
        x_flat = vec(model[:x])
        # seed the reverse pass with the sensitivity of `l` wrt each variable
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), x_flat, vec(dl_dx))
        # back-propagate through the solved problem
        DiffOpt.reverse_differentiate!(model)
        # sensitivities wrt the objective function parameters
        obj_sensitivity = MOI.get(model, DiffOpt.ReverseObjectiveFunction())
        # the linear term of the objective is q'x with q = -2y, so the coefficient
        # of each `x` in the sensitivity expression is dl/dq
        dl_dq = convert.(T, JuMP.coefficient.(obj_sensitivity, x_flat))
        # chain rule: dl/dy = dl/dq * dq/dy, with dq/dy = -2
        dq_dy = -2
        dl_dy = zeros(T, size(dl_dx))
        dl_dy[:] .= dl_dq .* dq_dy
        return (ChainRulesCore.NoTangent(), dl_dy)
    end
    return primal, pullback_matrix_relu
end

For more details about backpropagation and custom differentiation rules, see the Introduction section of the ChainRulesCore.jl documentation.

Define the network

layer_size = 10 # width of the hidden layer
m = Flux.Chain(
    Flux.Dense(28 * 28 => layer_size), # each 28 x 28 image is flattened to 784 inputs
    matrix_relu,                       # custom ReLU defined through an optimization problem
    Flux.Dense(layer_size => 10),      # one output per possible digit (0 to 9)
    Flux.softmax,                      # turn the outputs into class probabilities
)
Chain(
  Dense(784 => 10),                     # 7_850 parameters
  Main.var"Main".matrix_relu,
  Dense(10 => 10),                      # 110 parameters
  NNlib.softmax,
)                   # Total: 4 arrays, 7_960 parameters, 31.297 KiB.

Prepare data

N = 1000 # number of samples used from each split (all N are fed to the network at once)
# Preprocessing train data
# `MNIST.traintensor` / `MNIST.trainlabels` are deprecated in MLDatasets;
# use the `features` / `targets` fields of the dataset object instead.
train_data = MLDatasets.MNIST(split = :train)
imgs = train_data.features[:, :, 1:N]
labels = train_data.targets[1:N]
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) # stack images
train_Y = Flux.onehotbatch(labels, 0:9);
# Preprocessing test data
# `MNIST.testtensor` / `MNIST.testlabels` are deprecated in the same way.
test_data = MLDatasets.MNIST(split = :test)
test_imgs = test_data.features[:, :, 1:N]
test_labels = test_data.targets[1:N]
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
test_Y = Flux.onehotbatch(test_labels, 0:9);
┌ Warning: MNIST.traintensor() is deprecated, use `MNIST(split=:train).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:157
┌ Warning: MNIST.trainlabels() is deprecated, use `MNIST(split=:train).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:173
┌ Warning: MNIST.testtensor() is deprecated, use `MNIST(split=:test).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:165
┌ Warning: MNIST.testlabels() is deprecated, use `MNIST(split=:test).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/0MkOE/src/datasets/vision/mnist.jl:180

Define the input data. The original data is repeated `epochs` times because `Flux.train!` only loops through the data set once.

# Number of passes over the training data: ~1 minute (i7 8th gen with 16gb RAM).
# With `epochs = 100` the accuracy reaches 77.8% in about 2 minutes.
epochs = 50
# Lazily yield the same (inputs, labels) pair `epochs` times.
dataset = Iterators.repeated((train_X, train_Y), epochs);

Network training

Define the training loss function, the Flux optimizer, and a callback that prints the current training loss.

# Cross-entropy between the predictions of the network `m` on `x` and the labels `y`.
function custom_loss(x, y)
    prediction = m(x)
    return Flux.crossentropy(prediction, y)
end
# Adam optimizer with Flux's default settings.
opt = Flux.Adam()
# Callback that prints the loss on the full training set.
evalcb = () -> @show(custom_loss(train_X, train_Y))
#11 (generic function with 1 method)

Train to optimize network parameters

# Run a single pass over `dataset` (the training pair repeated `epochs` times),
# updating the parameters of `m` with `opt`, and report the elapsed time.
@time Flux.train!(
    custom_loss,
    Flux.params(m), # parameters of the network `m` to be optimized
    dataset,
    opt,
    # rate-limited loss printout; NOTE(review): the 5 is presumably a minimum
    # interval in seconds between calls — confirm against the Flux docs
    cb = Flux.throttle(evalcb, 5),
);
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(10 => 10)     # 110 parameters
│   summary(x) = "10×1000 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/hiqg1/src/layers/stateless.jl:60
custom_loss(train_X, train_Y) = 2.355365f0
custom_loss(train_X, train_Y) = 2.2240443f0
custom_loss(train_X, train_Y) = 2.1510334f0
custom_loss(train_X, train_Y) = 2.0600805f0
custom_loss(train_X, train_Y) = 1.9604436f0
custom_loss(train_X, train_Y) = 1.8702683f0
custom_loss(train_X, train_Y) = 1.7790897f0
custom_loss(train_X, train_Y) = 1.691865f0
custom_loss(train_X, train_Y) = 1.610134f0
custom_loss(train_X, train_Y) = 1.5316879f0
107.909324 seconds (76.76 M allocations: 4.763 GiB, 1.40% gc time, 0.69% compilation time)

Although our custom implementation is slow, because every forward and backward pass solves an optimization problem, it is able to reach an accuracy similar to that of the usual ReLU implementation.

Accuracy results

Average of correct guesses

# Fraction of samples in `x` whose predicted class (largest output of `m`)
# matches the class encoded in the one-hot labels `y`.
function accuracy(x, y)
    predicted = Flux.onecold(m(x))
    expected = Flux.onecold(y)
    return Statistics.mean(predicted .== expected)
end;

Training accuracy

# Share of the N training samples that the trained network classifies correctly.
accuracy(train_X, train_Y)
0.562

Test accuracy

# Share of the N held-out test samples that the trained network classifies correctly.
accuracy(test_X, test_Y)
0.478

Note that the accuracy is low because the training is deliberately simplified. It can be improved by increasing the number of samples `N`, the number of epochs `epochs`, and the hidden layer size `layer_size`.


This page was generated using Literate.jl.