Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
matrix:
group:
- Core
- OptimizationAuglag
- OptimizationBBO
- OptimizationCMAEvolutionStrategy
- OptimizationEvolutionary
Expand Down
25 changes: 25 additions & 0 deletions lib/OptimizationAuglag/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name = "OptimizationAuglag"
uuid = "2ea93f80-9333-43a1-a68d-1f53b957a421"
authors = ["paramthakkar123 <paramthakkar864@gmail.com>"]
version = "0.1.0"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be able to just need optimizationbase?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay will replace with that one


[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ForwardDiff = "1.0.1"
MLUtils = "0.4.8"
Optimization = "4.4.0"
OptimizationBase = "2.10.0"
OptimizationOptimisers = "0.3.8"
Test = "1.11.0"

[targets]
test = ["Test"]
14 changes: 11 additions & 3 deletions src/auglag.jl → ...imizationAuglag/src/OptimizationAuglag.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
module OptimizationAuglag

using Optimization
using OptimizationBase.SciMLBase: OptimizationProblem, OptimizationFunction, OptimizationStats
using OptimizationBase.LinearAlgebra: norm

@kwdef struct AugLag
inner::Any
τ = 0.5
Expand All @@ -20,7 +26,7 @@ SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
function __map_optimizer_args(cache::OptimizationBase.OptimizationCache, opt::AugLag;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
Expand Down Expand Up @@ -110,7 +116,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = p)
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
Expand Down Expand Up @@ -176,10 +182,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
break
end
end
stats = Optimization.OptimizationStats(; iterations = maxiters,
stats = OptimizationStats(; iterations = maxiters,
time = 0.0, fevals = maxiters, gevals = maxiters)
return SciMLBase.build_solution(
cache, cache.opt, θ, x,
stats = stats, retcode = opt_ret)
end
end

end
36 changes: 36 additions & 0 deletions lib/OptimizationAuglag/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using OptimizationBase
using MLUtils
using OptimizationOptimisers
using OptimizationAuglag
using ForwardDiff
using OptimizationBase: OptimizationCache
using OptimizationBase.SciMLBase: OptimizationFunction
using Test

@testset "OptimizationAuglag.jl" begin
x0 = (-pi):0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 126)

function loss(coeffs, data)
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
return sum(abs2, ypred .- data[2])
end

function cons1(res, coeffs, p = nothing)
res[1] = coeffs[1] * coeffs[5] - 1
return nothing
end

optf = OptimizationFunction(loss, OptimizationBase.AutoSparseForwardDiff(), cons = cons1)
callback = (st, l) -> (@show l; return false)

initpars = rand(5)
l0 = optf(initpars, (x0, y0))

prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt = solve(
prob, OptimizationAuglag.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
@test opt.objective < l0
end
1 change: 0 additions & 1 deletion src/Optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ include("utils.jl")
include("state.jl")
include("lbfgsb.jl")
include("sophia.jl")
include("auglag.jl")

export solve

Expand Down
6 changes: 0 additions & 6 deletions test/native.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,6 @@ prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.
opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)
@test opt1.objective < l0

prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt = solve(
prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
@test opt.objective < l0

optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
prob1 = OptimizationProblem(optf1, rand(5), data)
sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)
Expand Down
Loading