@@ -2,6 +2,7 @@ using Flux.Optimise
22using Flux. Optimise: runall
33using Flux: Params, gradient
44import FillArrays, ComponentArrays
5+ import Optimisers
56using Test
67using Random
78
@@ -167,21 +168,19 @@ end
167168@testset " update!: handle ComponentArrays" begin
168169 w = ComponentArrays. ComponentArray (a= 1.0 , b= [2 , 1 , 4 ], c= (a= 2 , b= [1 , 2 ]))
169170 wold = deepcopy (w)
170- θ = Flux. params ([w])
171- gs = gradient (() -> sum (w. a) + sum (w. c. b), θ)
172- opt = Descent (0.1 )
173- Flux. update! (opt, θ, gs)
174- @test w. a ≈ wold. a .- 0.1
171+ opt_state = Optimisers. setup (Optimisers. Descent (0.1 ), w)
172+ gs = gradient (w -> w. a + sum (w. c. b), w)[1 ]
173+ Flux. update! (opt_state, w, gs)
174+ @test w. a ≈ wold. a - 0.1
175175 @test w. b ≈ wold. b
176176 @test w. c. b ≈ wold. c. b .- 0.1
177177 @test w. c. a ≈ wold. c. a
178178
179179 w = ComponentArrays. ComponentArray (a= 1.0 , b= [2 , 1 , 4 ], c= (a= 2 , b= [1 , 2 ]))
180180 wold = deepcopy (w)
181- θ = Flux. params ([w])
182- gs = gradient (() -> sum (w), θ)
183- opt = Descent (0.1 )
184- Flux. update! (opt, θ, gs)
181+ opt_state = Optimisers. setup (Optimisers. Descent (0.1 ), w)
182+ gs = gradient (w -> sum (w), w)[1 ]
183+ Flux. update! (opt_state, w, gs)
185184 @test w ≈ wold .- 0.1
186185end
187186
0 commit comments