Logistic regression
using DataFrames
using Plots
using RDatasets
using Convex
using SCS
This is an example of logistic regression using the iris data from RDatasets. Our goal is to predict whether an iris belongs to the species versicolor from its sepal length, sepal width, petal length and petal width.
iris = dataset("datasets", "iris");
iris[1:10, :]
10×5 DataFrame
| Row | SepalLength | SepalWidth | PetalLength | PetalWidth | Species |
|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | Cat… |
| 1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 2 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 5 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
| 6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
| 7 | 4.6 | 3.4 | 1.4 | 0.3 | setosa |
| 8 | 5.0 | 3.4 | 1.5 | 0.2 | setosa |
| 9 | 4.4 | 2.9 | 1.4 | 0.2 | setosa |
| 10 | 4.9 | 3.1 | 1.5 | 0.1 | setosa |
We'll define Y as the outcome variable: +1 for versicolor and -1 otherwise.
Y = [species == "versicolor" ? 1.0 : -1.0 for species in iris.Species]
150-element Vector{Float64}:
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
⋮
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
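Encoding the labels as ±1 rather than 0/1 lets both classes share one formula. The sketch below assumes the standard logistic model, in which P(y | x) = logistic(y * x'β) for y in {-1, +1}; the helper names prob_of_label and nll are illustrative only, not part of any library. It checks that the two class probabilities sum to one and that the per-observation negative log-likelihood equals log(1 + exp(-y * x'β)).

```julia
# Logistic (sigmoid) function, same definition as used later on this page
logistic(z::Real) = inv(exp(-z) + one(z))

# Hypothetical helper: with labels y in {-1, +1}, P(y | x) = logistic(y * score),
# where score = x' * beta is the linear predictor
prob_of_label(y::Real, score::Real) = logistic(y * score)

# Hypothetical helper: per-observation negative log-likelihood,
# -log(logistic(y * score)) == log(1 + exp(-y * score))
nll(y::Real, score::Real) = log1p(exp(-y * score))

s = 0.7  # an arbitrary linear-predictor value for illustration
p_plus = prob_of_label(1.0, s)
p_minus = prob_of_label(-1.0, s)
println(p_plus + p_minus)                 # the two class probabilities sum to one
println(nll(1.0, s) ≈ -log(p_plus))       # NLL matches the logistic-loss term
```

Because logistic(z) + logistic(-z) = 1, a single expression in y * x'β covers both outcomes, which is why the objective can be written with -Y .* (X * beta).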
We'll create our data matrix X with one column per feature, plus a leading column of ones whose coefficient acts as the intercept (offset).
X = hcat(
ones(size(iris, 1)),
iris.SepalLength,
iris.SepalWidth,
iris.PetalLength,
iris.PetalWidth,
);
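To see why a column of ones plays the role of the intercept, here is a small sketch on made-up numbers (the matrix, intercept and weights below are hypothetical, not iris values): prepending ones makes X * beta equal to the first coefficient plus the remaining features times their weights, row by row.

```julia
# Hypothetical toy feature matrix: 3 observations, 2 features
features = [1.0 2.0;
            3.0 4.0;
            5.0 6.0]
m = size(features, 1)

# Prepend a column of ones, as done for the iris data above
Xtoy = hcat(ones(m), features)

beta0 = 0.5              # hypothetical intercept
w = [1.0, -2.0]          # hypothetical feature weights
betatoy = vcat(beta0, w)

# Each row of Xtoy * betatoy equals beta0 + (that row of features) * w
println(Xtoy * betatoy ≈ beta0 .+ features * w)
```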
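Now we solve the logistic regression problem by choosing beta to minimize the logistic loss, which is the negative log-likelihood of the ±1 labels under the logistic model.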
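Written out, with x_i the i-th row of X and y_i the i-th entry of Y: Convex.jl's logisticloss(u) sums log(1 + exp(u_i)) over the entries of u, so passing -Y .* (X * beta) gives the objective below. The second form, with σ the logistic function, shows it is the same as maximizing the log-likelihood.

```latex
\hat\beta = \arg\min_{\beta \in \mathbb{R}^{p}} \sum_{i=1}^{n} \log\bigl(1 + e^{-y_i x_i^{\top}\beta}\bigr)
          = \arg\max_{\beta \in \mathbb{R}^{p}} \sum_{i=1}^{n} \log \sigma\bigl(y_i x_i^{\top}\beta\bigr),
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}
```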
n, p = size(X)
beta = Variable(p)
problem = minimize(logisticloss(-Y .* (X * beta)))
solve!(problem, SCS.Optimizer; silent = true)
Problem statistics
problem is DCP : true
number of variables : 1 (5 scalar elements)
number of constraints : 0 (0 scalar elements)
number of coefficients : 1_050
number of atoms : 6
Solution summary
termination status : OPTIMAL
primal status : FEASIBLE_POINT
dual status : FEASIBLE_POINT
objective value : 72.535
Expression graph
minimize
└─ sum (convex; real)
└─ logsumexp (convex; real)
└─ hcat (affine; real)
├─ …
└─ …
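As a sanity check on the solver output, the objective can be evaluated directly at any coefficient vector. This is a minimal sketch; logistic_loss and softplus are hypothetical helpers written for illustration, not Convex.jl functions, and the toy data is made up. softplus uses the rearrangement z + log(1 + exp(-z)) for positive z so that exp does not overflow.

```julia
# Numerically stable softplus: log(1 + exp(z)) without overflow for large z
softplus(z::Real) = z > 0 ? z + log1p(exp(-z)) : log1p(exp(z))

# Hypothetical helper (not part of Convex.jl): the logistic-loss objective
# sum_i log(1 + exp(-y_i * x_i' * beta)) for labels y_i in {-1, +1}
function logistic_loss(X::AbstractMatrix, Y::AbstractVector, beta::AbstractVector)
    scores = X * beta                    # linear predictor for each row of X
    return sum(softplus.(-Y .* scores))
end

# Hypothetical toy data: intercept column plus one feature
Xtoy = [1.0 0.5; 1.0 -1.0; 1.0 2.0]
Ytoy = [1.0, -1.0, 1.0]

# At beta = 0 every term is log(2), so the loss is 3 * log(2) here
println(logistic_loss(Xtoy, Ytoy, zeros(2)))
```

In the session above, logistic_loss(X, Y, vec(evaluate(beta))) should come out close to the reported objective value; small differences are expected from SCS's convergence tolerance.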
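Let's see how well the model fits. We sort the observations by their fitted linear score X * beta, then plot the observed outcomes (rescaled from ±1 to 0/1) as points and the fitted probability of versicolor as a curve.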
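using Plots
# Logistic (sigmoid) function: maps a linear score to a probability in (0, 1)
logistic(x::Real) = inv(exp(-x) + one(x))
# Order observations by their fitted linear score X * beta
perm = sortperm(vec(X * evaluate(beta)))
# Observed outcomes, rescaled from {-1, +1} to {0, 1}, shown as points
plot(1:n, (Y[perm] .+ 1) / 2, st = :scatter)
# Fitted probability of versicolor for each observation, in the same order
plot!(1:n, logistic.(X * evaluate(beta))[perm])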
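The plot shows the fit visually; a single summary number can complement it. This is a minimal sketch assuming a 0.5 probability threshold (equivalently, predicting +1 when x'β > 0); classification_accuracy is a hypothetical helper and the toy data is made up for illustration.

```julia
using Statistics: mean

# Logistic (sigmoid) function, same definition as on this page
logistic(z::Real) = inv(exp(-z) + one(z))

# Hypothetical helper: fraction of observations whose predicted class matches
# the label, with labels in {-1, +1} and +1 predicted when probability > 0.5
function classification_accuracy(X::AbstractMatrix, Y::AbstractVector, betahat::AbstractVector)
    probs = logistic.(X * betahat)
    predicted = ifelse.(probs .> 0.5, 1.0, -1.0)
    return mean(predicted .== Y)
end

# Hypothetical toy data: intercept column plus one feature
Xtoy = [1.0 0.5; 1.0 -1.0; 1.0 2.0; 1.0 -0.2]
Ytoy = [1.0, -1.0, 1.0, 1.0]
println(classification_accuracy(Xtoy, Ytoy, [0.0, 1.0]))
```

With the fitted model, the corresponding call would be classification_accuracy(X, Y, vec(evaluate(beta))). Note this measures in-sample accuracy only, since the same data was used for fitting.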
This page was generated using Literate.jl.