Custom ReLU layer

We demonstrate how DiffOpt can be used to generate a simple neural network unit - the ReLU layer. A neural network is created using Flux.jl and trained on the MNIST dataset.

This tutorial uses the following packages

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
import Base.Iterators: repeated
using LinearAlgebra

The ReLU and its derivative

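Define the ReLU through an optimization problem: minimize `x'x - 2y'x` subject to `x >= 0`. This is a quadratic program whose solution is `max.(y, 0)`, i.e. the ReLU of `y`. The function below solves it with Ipopt and returns the solution of the problem.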

# Elementwise ReLU of `y`, computed by minimizing `x'x - 2y'x` over `x >= 0`.
#
# `model` is emptied before the problem is built, so anything it held is discarded.
# The decision variables are registered in `model` under the name `x` and have the
# same dimensions as `y`; their solution values are returned.
function matrix_relu(
    y::Matrix;
    model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer)),
)
    nrows, ncols = size(y)
    empty!(model)
    set_silent(model)
    @variable(model, x[1:nrows, 1:ncols] >= 0)
    # same objective as x'x - 2y'x, written entry by entry
    @objective(
        model,
        Min,
        sum(x[i, j]^2 - 2 * y[i, j] * x[i, j] for i in 1:nrows, j in 1:ncols)
    )
    optimize!(model)
    return value.(x)
end
matrix_relu (generic function with 1 method)

Define the reverse differentiation rule for the function we defined above.

# Reverse-mode rule for `matrix_relu`.
#
# The forward pass solves the optimization problem once and returns its solution.
# The pullback reuses that same solved model: it seeds DiffOpt with the incoming
# cotangent and reads back the sensitivity of the objective coefficients.
function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where {T}
    model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
    pv = matrix_relu(y; model = model)
    # Some value from the backpropagation (e.g., the loss) is denoted by `l`:
    # `dl_dx` is the incoming derivative of `l` wrt the solution `x`, and the
    # returned `dl_dy` is the derivative of `l` wrt `y`.
    function pullback_matrix_relu(dl_dx)
        # the cotangent may arrive as a lazy thunk; materialize it before indexing
        seed = ChainRulesCore.unthunk(dl_dx)
        x = model[:x] # decision variables registered by `matrix_relu`
        # set sensitivities
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), vec(x), vec(seed))
        # compute grad
        DiffOpt.reverse_differentiate!(model)
        # gradient wrt objective function parameters
        obj_exp = MOI.get(model, DiffOpt.ReverseObjectiveFunction())
        # The linear objective term is q'x with q = -2y, so dl/dq is the coefficient
        # of each `x` in `obj_exp` and dq/dy = -2. The result is sized from `x`
        # (which has the shape of `y`) and stored with element type `T`.
        dl_dy = zeros(T, size(x))
        dl_dy .= -2 .* JuMP.coefficient.(obj_exp, x)
        return (ChainRulesCore.NoTangent(), dl_dy)
    end
    return pv, pullback_matrix_relu
end

For more details about backpropagation, see the Introduction of the ChainRulesCore.jl documentation.

Define the network

# Width of the hidden layer that is passed through `matrix_relu`.
layer_size = 10
m = Flux.Chain(
    Flux.Dense(784 => layer_size), # 784 inputs: one per pixel of a 28 x 28 image
    matrix_relu,                   # optimization-based ReLU defined above
    Flux.Dense(layer_size => 10),  # 10 outputs: one per digit (0 to 9)
    Flux.softmax,
)
Chain(
  Dense(784 => 10),                     # 7_850 parameters
  Main.matrix_relu,
  Dense(10 => 10),                      # 110 parameters
  NNlib.softmax,
)                   # Total: 4 arrays, 7_960 parameters, 31.344 KiB.

Prepare data

N = 1000 # number of samples, all used together as one batch
# The `MNIST.traintensor`/`trainlabels`/`testtensor`/`testlabels` accessors are
# deprecated; each split is loaded once and read through `features`/`targets`.
# Preprocessing train data
train_data = MLDatasets.MNIST(; split = :train)
imgs = train_data.features[:, :, 1:N]
labels = train_data.targets[1:N]
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) # one column per image
train_Y = Flux.onehotbatch(labels, 0:9);
# Preprocessing test data
test_data = MLDatasets.MNIST(; split = :test)
test_imgs = test_data.features[:, :, 1:N]
test_labels = test_data.targets[1:N]
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
test_Y = Flux.onehotbatch(test_labels, 0:9);
┌ Warning: MNIST.traintensor() is deprecated, use `MNIST(split=:train).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/7BhSm/src/datasets/vision/mnist.jl:157
┌ Warning: MNIST.trainlabels() is deprecated, use `MNIST(split=:train).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/7BhSm/src/datasets/vision/mnist.jl:173
┌ Warning: MNIST.testtensor() is deprecated, use `MNIST(split=:test).features` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/7BhSm/src/datasets/vision/mnist.jl:165
┌ Warning: MNIST.testlabels() is deprecated, use `MNIST(split=:test).targets` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/7BhSm/src/datasets/vision/mnist.jl:180

Define the input data. The original data is repeated `epochs` times because `Flux.train!` only loops through the data set once.

# Number of passes over the training data.
# Reported by the author: 50 takes ~1 minute (i7 8th gen with 16gb RAM);
# 100 leads to 77.8% in about 2 minutes.
epochs = 50
# Lazily yields the same (inputs, targets) pair `epochs` times.
dataset = Iterators.repeated((train_X, train_Y), epochs);

Network training

Define the training loss function, the Flux optimizer, and a callback that prints the current training loss.

# Cross-entropy between the network output `m(x)` and the one-hot targets `y`.
function custom_loss(x, y)
    return Flux.crossentropy(m(x), y)
end
# ADAM optimizer with Flux's default settings.
opt = Flux.ADAM()
# Callback that prints the loss evaluated on the whole training set.
evalcb = () -> @show(custom_loss(train_X, train_Y))
#11 (generic function with 1 method)

Train to optimize network parameters

# One `train!` call consumes every entry of `dataset`, i.e. all `epochs` passes;
# the loss callback is wrapped in `Flux.throttle` with a timeout of 5.
@time Flux.train!(
    custom_loss,
    Flux.params(m),
    dataset,
    opt;
    cb = Flux.throttle(evalcb, 5),
);
custom_loss(train_X, train_Y) = 2.3553663399193323
custom_loss(train_X, train_Y) = 2.2240455562806773
custom_loss(train_X, train_Y) = 2.1510348515820947
custom_loss(train_X, train_Y) = 2.0600817034320347
custom_loss(train_X, train_Y) = 1.9604446390166004
custom_loss(train_X, train_Y) = 1.8702694198201986
custom_loss(train_X, train_Y) = 1.7790908032583186
custom_loss(train_X, train_Y) = 1.69186586227992
custom_loss(train_X, train_Y) = 1.6101347352056274
custom_loss(train_X, train_Y) = 1.531688536945617
100.842656 seconds (106.24 M allocations: 4.890 GiB, 2.10% gc time, 1.03% compilation time)

Although our custom implementation is slow, it reaches an accuracy similar to that of the usual ReLU implementation.

Accuracy results

Average of correct guesses

# Fraction of samples whose predicted class (`onecold` of the network output)
# equals the class encoded by the one-hot targets `y`.
function accuracy(x, y)
    predicted = Flux.onecold(m(x))
    expected = Flux.onecold(y)
    return Statistics.mean(predicted .== expected)
end;

Training accuracy

# fraction of the `N` training samples that the trained network classifies correctly
accuracy(train_X, train_Y)
0.562

Test accuracy

# same measure on the `N` test samples, which were not used for training
accuracy(test_X, test_Y)
0.478

Note that the accuracy is low due to simplified training. It is possible to increase the number of samples `N`, the number of epochs `epochs`, and the hidden layer size `layer_size`.


This page was generated using Literate.jl.