From ec8699de115e01bf283a3c793b6f0fe5f7eca85a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 01:45:06 +0200 Subject: [PATCH 01/15] Flux.state --- src/loading.jl | 29 +++++++++++++++++++++++++++++ test/utils.jl | 17 +++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/loading.jl b/src/loading.jl index 5cdd129936..dbff945f09 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -104,3 +104,32 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) return dst end + +""" + state(x; full=false) + +Return an object with the same nested structure as `x` +according to `Functors.children()`, but made only of +basic containers (e.g. named tuples, tuples, arrays, and dictionaries). + +If `full` is `false` (default), then only arrays and scalar original leaves are used as leaf values in the return, +with the other leaves being replaced by `nothing`. + +This method is particularly useful for saving and loading models, since it doesn't +require the user to specify the model type. +The returned state, can be passed to `loadmodel!` to restore the model. +""" +function state(x; full=false) + if Functors.isleaf(x) + if full + return x + else + return x isa Union{Number, AbstractArray} ? x : nothing + end + else + return valuemap(c -> state(c; full), Functors.children(x)) + end +end + +valuemap(f, x) = map(f, x) +valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x) diff --git a/test/utils.jl b/test/utils.jl index bda738f6a0..7234ae3ec7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -612,6 +612,23 @@ end @test ∇p ≈ destructure(∇m)[1] end end + + @testset "state" begin + m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) + m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) + s = Flux.state(m5) + @test s isa NamedTuple + @test fieldnames(typeof(s)) == (:layers,) + @test s.layers isa Tuple + @test length(s.layers) == 2 + @test s.layers[1].weight === m5[1].weight + @test s.layers[1].σ === nothing + @test s.layers[2].layers[1].weight === m5[2].layers[1].weight + + Flux.loadmodel!(m6, s) + @test m6[1].weight == m5[1].weight + @test all(m6[2].layers[1].bias .== m5[2].layers[1].bias) + end end @testset "Train and test mode" begin From db801711373a930285cc2e499637bd221b5cf167 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 01:56:30 +0200 Subject: [PATCH 02/15] some docs --- NEWS.md | 4 +++- docs/make.jl | 1 + docs/src/models/saving.md | 6 ++++++ docs/src/saving.md | 4 ++++ 4 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 docs/src/models/saving.md diff --git a/NEWS.md b/NEWS.md index e7fad6ccf0..6e9f80f73f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # Flux Release Notes +<<<<<<< HEAD See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. ## v0.13.16 @@ -7,12 +8,13 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl Thus `LayerNorm(3; ϵ=1e-4)` (not `ε`!) should become `LayerNorm(3; eps=1e-4)`. * `DataLoader(...) |> gpu` will now produce a special iterator, moving each batch as needed, instead of giving an error. +* Added `Flux.state` returning the internal state of the model for serialization. ## v0.13.15 * Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer. * `f16, f32, f64` now specifically target floating point arrays (i.e. integers arrays and other types are preserved). * `f16, f32, f64` can now handle `Complex{<:AbstractFloat}` arrays. -* Added `EmbeddingBag` layer +* Added `EmbeddingBag` layer. ## v0.13.14 * Fixed various deprecation warnings, from `Zygone.@nograd` and `Vararg`. diff --git a/docs/make.jl b/docs/make.jl index a536014d99..f82415f669 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -35,6 +35,7 @@ makedocs( "Gradients -- Zygote.jl" => "training/zygote.md", "Batching Data -- MLUtils.jl" => "data/mlutils.md", "OneHotArrays.jl" => "data/onehot.md", + "Saving and Loading" => "models/saving.md" "Low-level Operations -- NNlib.jl" => "models/nnlib.md", "Nested Structures -- Functors.jl" => "models/functors.md", ], diff --git a/docs/src/models/saving.md b/docs/src/models/saving.md new file mode 100644 index 0000000000..cf7672d375 --- /dev/null +++ b/docs/src/models/saving.md @@ -0,0 +1,6 @@ +# Saving and Loading + +```julia +Flux.loadparams! +Flux.state +``` diff --git a/docs/src/saving.md b/docs/src/saving.md index 853f4b0d9c..e74b9cc2ef 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -138,3 +138,7 @@ exactly where you left off. BSON is smart enough to [cache values](https://githu opt = Adam() @save "model-$(now()).bson" model opt ``` + +## Saving the state only + +An alternative From f97602b3e69ca6d7aa45f9c075b91d1c7bd257be Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 01:57:26 +0200 Subject: [PATCH 03/15] some docs --- docs/src/saving.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index e74b9cc2ef..1ce5143fcd 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -141,4 +141,7 @@ opt = Adam() ## Saving the state only -An alternative +An alternative ... TODO + +```julia + From e4e3921bced0ebe6267e5bbce556d769e742c175 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 08:35:39 +0200 Subject: [PATCH 04/15] add keep keyword --- docs/make.jl | 1 - docs/src/destructure.md | 9 +- docs/src/models/saving.md | 6 - src/loading.jl | 49 +++++--- test/loading.jl | 227 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 + test/utils.jl | 193 -------------------------------- 7 files changed, 272 insertions(+), 217 deletions(-) delete mode 100644 docs/src/models/saving.md create mode 100644 test/loading.jl diff --git a/docs/make.jl b/docs/make.jl index f82415f669..a536014d99 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -35,7 +35,6 @@ makedocs( "Gradients -- Zygote.jl" => "training/zygote.md", "Batching Data -- MLUtils.jl" => "data/mlutils.md", "OneHotArrays.jl" => "data/onehot.md", - "Saving and Loading" => "models/saving.md" "Low-level Operations -- NNlib.jl" => "models/nnlib.md", "Nested Structures -- Functors.jl" => "models/functors.md", ], diff --git a/docs/src/destructure.md b/docs/src/destructure.md index 6e9eac191e..0ccfff54f8 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -72,4 +72,11 @@ Another kind of flat view of a nested model is provided by the `modules` command ```@docs Flux.modules -``` \ No newline at end of file +``` + +### Saving and Loading + +```@docs +Flux.loadmodel! +Flux.state +``` diff --git a/docs/src/models/saving.md b/docs/src/models/saving.md deleted file mode 100644 index cf7672d375..0000000000 --- a/docs/src/models/saving.md +++ /dev/null @@ -1,6 +0,0 @@ -# Saving and Loading - -```julia -Flux.loadparams! -Flux.state -``` diff --git a/src/loading.jl b/src/loading.jl index dbff945f09..8e2e289ca7 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -106,30 +106,47 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) end """ - state(x; full=false) + state(x; keep = leaf -> !(leaf isa Function)) Return an object with the same nested structure as `x` -according to `Functors.children()`, but made only of +according to `Functors.children`, but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries). -If `full` is `false` (default), then only arrays and scalar original leaves are used as leaf values in the return, -with the other leaves being replaced by `nothing`. +This method is particularly useful for saving and loading models, +since it doesn't require the user to specify the model type. +The state can be passed to `loadmodel!` to restore the model. -This method is particularly useful for saving and loading models, since it doesn't -require the user to specify the model type. -The returned state, can be passed to `loadmodel!` to restore the model. +The `keep` function is applied on the leaves of `x`. +If `keep(leaf)` is `false` , the leaf is replaced by `nothing`, +otherwise it is left as is. By default, all functions are excluded. + +# Examples + +```julia-repl +julia> m1 = Chain(Dense(1, 2, tanh), Dense(2, 1)); + +julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1)); + +julia> s = Flux.state(m1) +layers = ((weight = Float32[-0.56867087; 1.229064;;], bias = Float32[0.0, 0.0], σ = nothing), (weight = Float32[0.23323897 -0.5561147], bias = Float32[0.0], σ = nothing)),) + +julia> Flux.loadmodel!(m2, s); + +julia> m2[1].weight == m1[1].weight +true +``` """ -function state(x; full=false) +function state(x; keep = _state_keep) if Functors.isleaf(x) - if full - return x - else - return x isa Union{Number, AbstractArray} ? x : nothing - end + return keep(x) ? x : nothing else - return valuemap(c -> state(c; full), Functors.children(x)) + return _valuemap(c -> state(c; keep), Functors.children(x)) end end -valuemap(f, x) = map(f, x) -valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x) +_state_keep(x::Function) = false +_state_keep(x) = true + +# map for tuples, namedtuples, and dicts +_valuemap(f, x) = map(f, x) +_valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x) diff --git a/test/loading.jl b/test/loading.jl new file mode 100644 index 0000000000..11ea368462 --- /dev/null +++ b/test/loading.jl @@ -0,0 +1,227 @@ + +ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense +dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout)) +dm(bias) = Chain( + dl(3, 5, bias), + dl(5, 4, bias), + dl(4, 3, bias) +) + +nobias(n) = false +testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) + @test l1.weight == l2.weight + @test l1.bias == l2.bias + @test_skip typeof(l1.bias) === typeof(l2.bias) +end + + +@testset "loadmodel!(dst, src)" begin + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dense(5, 2)) + m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) + m4 = Chain(Dense(10, 6), Dense(6, 2)) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) + m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) + + loadmodel!(m1, m2) + # trainable parameters copy over + @test m1[1].weight == m2[1].weight + @test m1[1].bias == m2[1].bias + # non-array leaves are untouched + @test m1[2].σ == relu + + loadmodel!(m5, m6) + # more complex nested structures also work + @test m5[1].weight == m6[1].weight + @test m5[2][1].weight == m6[2][1].weight + # false bias is not overwritten + @test m5[2][1].bias == false + + # mismatched nodes throw an error + @test_throws ArgumentError loadmodel!(m1, m3) + @test_throws ArgumentError loadmodel!(m1, m5) + # size mismatches throw an error + @test_throws DimensionMismatch loadmodel!(m1, m4) + + # tests for BatchNorm and Dropout + m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) + m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) + m2[2].μ .= rand(Float32, size(m2[2].μ)...) + loadmodel!(m1, m2) + # non-trainable parameters are copied as well + @test m1[2].μ == m2[2].μ + # functions are not copied + @test m1[3] == Flux.flatten + # dropout rate is not copied + @test m1[4].p == 0.2 + + # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) + # tests Chain(...) vs Chain([...]) + # tests MaxPool + # tests testmode!/trainmode! is not copied + # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model + chain1 = Chain(Dropout(0.2), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 32 => 16, relu), + Dropout(0.2), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 16 => 10, relu), + Dropout(0.2), + x -> reshape(x, :, size(x, 4)), + Dropout(0.2), + Dense(90, 10), + softmax) + chain2 = Chain([Dropout(0.1), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 32 => 16, relu), + Dropout(0.1), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 16 => 10, relu), + Dropout(0.1), + x -> reshape(x, :, size(x, 4)), + Dropout(0.1), + Dense(90, 10), + softmax]) + chain2[3].μ .= 5f0 + chain2[3].σ² .= 2f0 + testmode!(chain2) + loadmodel!(chain1, chain2) + for (dst, src) in zip(chain1, chain2) + if dst isa Dropout + @test dst.p == 0.2 + elseif dst isa Union{Conv, Dense} + @test dst.weight == src.weight + @test dst.bias == src.bias + elseif dst isa MaxPool + @test dst.k == (2, 2) + elseif dst isa BatchNorm + @test dst.μ == src.μ + @test dst.σ² == src.σ² + @test isnothing(dst.active) + end + end + + # copy only a subset of the model + chain1[end - 1].weight .= 1f0 + chain1[3].μ .= 3f0 + chain1[2].bias .= 5f0 + loadmodel!(chain2[end - 1], chain1[end - 1]) + loadmodel!(chain2[3], chain1[3]) + @test chain2[end - 1].weight == chain1[end - 1].weight + @test chain2[3].μ == chain1[3].μ + @test chain2[2].bias != chain1[2].bias + + # test shared weights + shared_dst = Dense(10 => 10) + shared_src = Dense(10 => 10) + # matched weights are okay + m1 = Chain(shared_dst, Dense(shared_dst.weight)) + m2 = Chain(shared_src, Dense(shared_src.weight)) + loadmodel!(m1, m2) + @test m1[1].weight === m1[2].weight + @test m1[1].weight == m2[2].weight + # mismatched weights are an error + m2 = Chain(Dense(10 => 10), Dense(10 => 10)) + @test_throws ErrorException loadmodel!(m1, m2) + # loading into tied weights with absent parameter is okay when the dst == zero + b = Flux.zeros32(5) + m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) + m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) + loadmodel!(m1, m2) + @test m1[1].bias === m1[2].bias + @test iszero(m1[1].bias) + # loading into tied weights with absent parameter is bad when the dst != zero + m2[1].bias .= 1 + @test_throws ErrorException loadmodel!(m1, m2) + + @testset "loadmodel! & filter" begin + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) + m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) + + # this will not error cause Dropout is skipped + loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) + @test m1[1].weight == m2[1].weight + @test m1[2].weight == m2[3].weight + + # this will not error cause Dropout is skipped + loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) + @test m3[1].weight == m2[1].weight + @test m3[2].weight == m2[3].weight + end + + @testset "loadmodel! & absent bias" begin + m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) + m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) + m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) + + Flux.loadmodel!(m1, m2) + @test m1[1].bias == 7:9 + @test sum(m1[1].weight) == 21 + + # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it + m1 = Flux.loadmodel!(m1, m0) + @test iszero(m1[1].bias) + @test sum(m1[1].weight) == 6 # written before error + + # load into a model without bias -- should it ignore the parameter which has no home, or error? + m0 = Flux.loadmodel!(m0, m2) + @test iszero(m0[1].bias) # obviously unchanged + @test sum(m0[1].weight) == 21 + end +end + +@testset "loadmodel!(dst, src) with BSON" begin + m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) + m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1)) + @test m1[1].weight != m2[1].weight + mktempdir() do dir + BSON.@save joinpath(dir, "test.bson") m1 + m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1]) + @test m1[1].weight == m2[1].weight + end +end + +@testset "state" begin + m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) + m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) + s = Flux.state(m1) + @test s isa NamedTuple + @test fieldnames(typeof(s)) == (:layers,) + @test s.layers isa Tuple + @test length(s.layers) == 2 + @test s.layers[1].weight === m1[1].weight + @test s.layers[1].σ === nothing + @test s.layers[2].layers[1].weight === m1[2].layers[1].weight + + Flux.loadmodel!(m2, s) + @test m2[1].weight == m1[1].weight + @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias) + + @testset "track active state and batch norm params" begin + m3 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2), BatchNorm(2)) + trainmode!(m3) + s = Flux.state(m3) + @test s.layers[2].active == true + @test s.layers[2].p == 0.2 + @test s.layers[4] == (λ = nothing, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0], + μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, + track_stats = true, active = true, chs = 2) + end + + @testset "keep" begin + s = Flux.state(m1, keep = x -> x isa AbstractArray) + @test s.layers[1].weight isa AbstractArray + @test s.layers[1].σ === nothing + @test s.layers[2].connection === nothing + @test s.layers[2].layers[1].bias === nothing + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 09a65bb046..8285a712d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,10 @@ Random.seed!(0) include("utils.jl") end + @testset "Loading" begin + include("loading.jl") + end + @testset "Optimise / Train" begin include("optimise.jl") include("train.jl") diff --git a/test/utils.jl b/test/utils.jl index 7234ae3ec7..bac8deefa6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -419,182 +419,6 @@ end @test_skip typeof(l1.bias) === typeof(l2.bias) end - - @testset "loadmodel!(dst, src)" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dense(5, 2)) - m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) - m4 = Chain(Dense(10, 6), Dense(6, 2)) - m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) - m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) - - loadmodel!(m1, m2) - # trainable parameters copy over - @test m1[1].weight == m2[1].weight - @test m1[1].bias == m2[1].bias - # non-array leaves are untouched - @test m1[2].σ == relu - - loadmodel!(m5, m6) - # more complex nested structures also work - @test m5[1].weight == m6[1].weight - @test m5[2][1].weight == m6[2][1].weight - # false bias is not overwritten - @test m5[2][1].bias == false - - # mismatched nodes throw an error - @test_throws ArgumentError loadmodel!(m1, m3) - @test_throws ArgumentError loadmodel!(m1, m5) - # size mismatches throw an error - @test_throws DimensionMismatch loadmodel!(m1, m4) - - # tests for BatchNorm and Dropout - m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) - m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) - m2[2].μ .= rand(Float32, size(m2[2].μ)...) - loadmodel!(m1, m2) - # non-trainable parameters are copied as well - @test m1[2].μ == m2[2].μ - # functions are not copied - @test m1[3] == Flux.flatten - # dropout rate is not copied - @test m1[4].p == 0.2 - - # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) - # tests Chain(...) vs Chain([...]) - # tests MaxPool - # tests testmode!/trainmode! is not copied - # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model - chain1 = Chain(Dropout(0.2), - Conv((3, 3), 1 => 32, relu), - BatchNorm(32, relu), - MaxPool((2, 2)), - Dropout(0.2), - Conv((3, 3), 32 => 16, relu), - Dropout(0.2), - MaxPool((2, 2)), - Dropout(0.2), - Conv((3, 3), 16 => 10, relu), - Dropout(0.2), - x -> reshape(x, :, size(x, 4)), - Dropout(0.2), - Dense(90, 10), - softmax) - chain2 = Chain([Dropout(0.1), - Conv((3, 3), 1 => 32, relu), - BatchNorm(32, relu), - MaxPool((3, 3)), - Dropout(0.1), - Conv((3, 3), 32 => 16, relu), - Dropout(0.1), - MaxPool((3, 3)), - Dropout(0.1), - Conv((3, 3), 16 => 10, relu), - Dropout(0.1), - x -> reshape(x, :, size(x, 4)), - Dropout(0.1), - Dense(90, 10), - softmax]) - chain2[3].μ .= 5f0 - chain2[3].σ² .= 2f0 - testmode!(chain2) - loadmodel!(chain1, chain2) - for (dst, src) in zip(chain1, chain2) - if dst isa Dropout - @test dst.p == 0.2 - elseif dst isa Union{Conv, Dense} - @test dst.weight == src.weight - @test dst.bias == src.bias - elseif dst isa MaxPool - @test dst.k == (2, 2) - elseif dst isa BatchNorm - @test dst.μ == src.μ - @test dst.σ² == src.σ² - @test isnothing(dst.active) - end - end - - # copy only a subset of the model - chain1[end - 1].weight .= 1f0 - chain1[3].μ .= 3f0 - chain1[2].bias .= 5f0 - loadmodel!(chain2[end - 1], chain1[end - 1]) - loadmodel!(chain2[3], chain1[3]) - @test chain2[end - 1].weight == chain1[end - 1].weight - @test chain2[3].μ == chain1[3].μ - @test chain2[2].bias != chain1[2].bias - - # test shared weights - shared_dst = Dense(10 => 10) - shared_src = Dense(10 => 10) - # matched weights are okay - m1 = Chain(shared_dst, Dense(shared_dst.weight)) - m2 = Chain(shared_src, Dense(shared_src.weight)) - loadmodel!(m1, m2) - @test m1[1].weight === m1[2].weight - @test m1[1].weight == m2[2].weight - # mismatched weights are an error - m2 = Chain(Dense(10 => 10), Dense(10 => 10)) - @test_throws ErrorException loadmodel!(m1, m2) - # loading into tied weights with absent parameter is okay when the dst == zero - b = Flux.zeros32(5) - m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) - m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) - loadmodel!(m1, m2) - @test m1[1].bias === m1[2].bias - @test iszero(m1[1].bias) - # loading into tied weights with absent parameter is bad when the dst != zero - m2[1].bias .= 1 - @test_throws ErrorException loadmodel!(m1, m2) - - @testset "loadmodel! & filter" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) - m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) - - # this will not error cause Dropout is skipped - loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) - @test m1[1].weight == m2[1].weight - @test m1[2].weight == m2[3].weight - - # this will not error cause Dropout is skipped - loadmodel!(m2, m3; filter = x -> !(x isa Dropout)) - @test m3[1].weight == m2[1].weight - @test m3[2].weight == m2[3].weight - end - - @testset "loadmodel! & absent bias" begin - m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) - m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) - m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - - Flux.loadmodel!(m1, m2) - @test m1[1].bias == 7:9 - @test sum(m1[1].weight) == 21 - - # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it - m1 = Flux.loadmodel!(m1, m0) - @test iszero(m1[1].bias) - @test sum(m1[1].weight) == 6 # written before error - - # load into a model without bias -- should it ignore the parameter which has no home, or error? - m0 = Flux.loadmodel!(m0, m2) - @test iszero(m0[1].bias) # obviously unchanged - @test sum(m0[1].weight) == 21 - end - end - - @testset "loadmodel!(dst, src) with BSON" begin - m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1)) - @test m1[1].weight != m2[1].weight - mktempdir() do dir - BSON.@save joinpath(dir, "test.bson") m1 - m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1]) - @test m1[1].weight == m2[1].weight - end - end - @testset "destructure" begin import Flux: destructure @testset "Bias type $bt" for bt in (zeros, nobias) @@ -612,23 +436,6 @@ end @test ∇p ≈ destructure(∇m)[1] end end - - @testset "state" begin - m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) - m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) - s = Flux.state(m5) - @test s isa NamedTuple - @test fieldnames(typeof(s)) == (:layers,) - @test s.layers isa Tuple - @test length(s.layers) == 2 - @test s.layers[1].weight === m5[1].weight - @test s.layers[1].σ === nothing - @test s.layers[2].layers[1].weight === m5[2].layers[1].weight - - Flux.loadmodel!(m6, s) - @test m6[1].weight == m5[1].weight - @test all(m6[2].layers[1].bias .== m5[2].layers[1].bias) - end end @testset "Train and test mode" begin From a4f960701566a474af3e33daca7810fb73613c75 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 28 Apr 2023 06:27:58 +0200 Subject: [PATCH 05/15] update --- docs/src/destructure.md | 6 +- docs/src/saving.md | 161 ++++++++++++++++++++-------------------- src/loading.jl | 60 +++++++++------ test/loading.jl | 20 ++--- 4 files changed, 132 insertions(+), 115 deletions(-) diff --git a/docs/src/destructure.md b/docs/src/destructure.md index 0ccfff54f8..1cdcad5ce7 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -74,9 +74,9 @@ Another kind of flat view of a nested model is provided by the `modules` command Flux.modules ``` -### Saving and Loading +### Save and Load ```@docs -Flux.loadmodel! Flux.state -``` +Flux.loadmodel! +``` \ No newline at end of file diff --git a/docs/src/saving.md b/docs/src/saving.md index 1ce5143fcd..e34fbfaae2 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -1,7 +1,10 @@ # Saving and Loading Models You may wish to save models so that they can be loaded and run in a later -session. The easiest way to do this is via +session. Flux provides a number of ways to do this. +The recommended way, which is the most robust one for long term storage, +is to use [`Flux.state`](@ref) in combination with a serialization format like +[JLD2.jl](https://juliaio.github.io/JLD2.jl/dev/) or [BSON.jl](https://github.com/JuliaIO/BSON.jl). Save a model: @@ -9,87 +12,42 @@ Save a model: ```jldoctest saving julia> using Flux -julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax) -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. +julia> struct MyModel + net + end -julia> using BSON: @save +julia> Flux.@functor MyModel -julia> @save "mymodel.bson" model +julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))) + +julia> model = MyModel() +MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) + +julia> model_state = Flux.state(model); + +julia> using JLD2 + +julia> jldsave("mymodel.jld2"; model_state) ``` -Load it again: +Load it again in a new session using [`Flux.loadmodel!`](@ref): ```jldoctest saving -julia> using Flux # Flux must be loaded before calling @load +julia> using Flux, JLD2 -julia> using BSON: @load +julia> model_state = JLD2.load("mymodel.jld2", "model_state") -julia> @load "mymodel.bson" model +julia> model = MyModel(); # MyModel definition must be available -julia> model -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. +julia> Flux.loadmodel!(model, model_state); ``` -Models are just normal Julia structs, so it's fine to use any Julia storage -format for this purpose. BSON.jl is particularly well supported and most likely -to be forwards compatible (that is, models saved now will load in future -versions of Flux). - !!! note If a saved model's parameters are stored on the GPU, the model will not load later on if there is no GPU support available. It's best to [move your model to the CPU](gpu.md) with `cpu(model)` before saving it. -!!! warning - - Previous versions of Flux suggested saving only the model weights using - `@save "mymodel.bson" params(model)`. - This is no longer recommended and even strongly discouraged. - Saving models this way will only store the trainable parameters which - will result in incorrect behavior for layers like `BatchNorm`. - -```julia -julia> using Flux - -julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax) -Chain( - Dense(10 => 5, relu), # 55 parameters - Dense(5 => 2), # 12 parameters - NNlib.softmax, -) # Total: 4 arrays, 67 parameters, 524 bytes. - -julia> weights = Flux.params(model); -``` - -Loading the model as shown above will return a new model with the stored parameters. -But sometimes you already have a model, and you want to load stored parameters into it. -This can be done as - -```julia -using Flux: loadmodel! -using BSON - -# some predefined model -model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax) - -# load one model into another -model = loadmodel!(model, BSON.load("mymodel.bson")[:model]) -``` - -This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory. - -```@docs -Flux.loadmodel! -``` ## Checkpointing @@ -98,50 +56,91 @@ In longer training runs it's a good idea to periodically save your model, so tha ```jldoctest saving julia> using Flux: throttle -julia> using BSON: @save +julia> using JLD2 -julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax) +julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2)) Chain( Dense(10 => 5, relu), # 55 parameters Dense(5 => 2), # 12 parameters - NNlib.softmax, ) # Total: 4 arrays, 67 parameters, 524 bytes. julia> evalcb = throttle(30) do - # Show loss - @save "model-checkpoint.bson" model + jldsave("model-checkpoint.jld2", model_state = Flux.state(m)) end; ``` -This will update the `"model-checkpoint.bson"` file every thirty seconds. +This will update the `"model-checkpoint.jld2"` file every thirty seconds. You can get more advanced by saving a series of models throughout training, for example ```julia -@save "model-$(now()).bson" model +jldsave("model-$(now()).jld2", model_state = Flux.state(m)) ``` -will produce a series of models like `"model-2018-03-06T02:57:10.41.bson"`. You +will produce a series of models like `"model-2018-03-06T02:57:10.41.jld2"`. You could also store the current test set loss, so that it's easy to (for example) revert to an older copy of the model if it starts to overfit. ```julia -@save "model-$(now()).bson" model loss = testloss() +jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss()) ``` -Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are stateful optimisers (which usually utilize an `IdDict` to store their state), and the randomness used to partition the original data into the training and validation sets. +Note that to resume a model's training, you might need to restore other stateful parts of your training loop. Possible examples are the optimiser state and the randomness used to partition the original data into the training and validation sets. You can store the optimiser state alongside the model, to resume training -exactly where you left off. BSON is smart enough to [cache values](https://github.com/JuliaIO/BSON.jl/blob/v0.3.4/src/write.jl#L71) and insert links when saving, but only if it knows everything to be saved up front. Thus models and optimisers must be saved together to have the latter work after restoring. +exactly where you left off: ```julia -opt = Adam() -@save "model-$(now()).bson" model opt +model = MyModel() +opt_state = Flux.setup(AdamW(), model) + +# ... train model ... + +model_state = Flux.state(model) +jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state) ``` -## Saving the state only +# Saving Models as Julia Structs -An alternative ... TODO +Models are just normal Julia structs, so it's fine to use any Julia storage +format to save the struct as it is instead of saving the state returned by [`Flux.state`](@ref). +[BSON.jl](https://github.com/JuliaIO/BSON.jl) is particularly convenient for this, +since it can also save anynomous functions, which are sometimes part of a model definition. -```julia +Save a model: + +```jldoctest saving +julia> using Flux + +julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2)); + +julia> using BSON: @save + +julia> @save "mymodel.bson" model +``` +Load it again in a new session: + +```jldoctest saving +julia> using Flux, BSON + +julia> BSON.@load "mymodel.bson" model + +julia> model +Chain( + Dense(10 => 5, relu), # 55 parameters + Dense(5 => 2), # 12 parameters +) # Total: 4 arrays, 67 parameters, 524 bytes. +``` +!!! warning + Saving models this way could lead to compatibility issues across julia versions + and across Flux versions if some of the Flux layers' internals are changed. + It is therefore not recommended for long term storage, use [`Flux.state`](@ref) instead. + +!!! warning + + Previous versions of Flux suggested saving only the model weights using + `@save "mymodel.bson" params(model)`. + This is no longer recommended and even strongly discouraged. + Saving models this way will only store the trainable parameters which + will result in incorrect behavior for layers like `BatchNorm`. diff --git a/src/loading.jl b/src/loading.jl index 8e2e289ca7..00b3deda96 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -47,6 +47,8 @@ Non-array elements (such as activation functions) are not copied and need not ma Zero bias vectors and `bias=false` are considered equivalent (see extended help for more details). +See also [`Flux.state`](@ref). + # Examples ```julia julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0])) @@ -106,23 +108,28 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) end """ - state(x; keep = leaf -> !(leaf isa Function)) + state(x) + +Return an object with the same nested structure as `x` according to `Functors.children`, +but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries). -Return an object with the same nested structure as `x` -according to `Functors.children`, but made only of -basic containers (e.g. named tuples, tuples, arrays, and dictionaries). +Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays, +such as numbers, symbols, strings, and nothing values. The leaf types that end up in the state +could increase in the future. This method is particularly useful for saving and loading models, -since it doesn't require the user to specify the model type. -The state can be passed to `loadmodel!` to restore the model. +since the state contain only simple data types that can be easily serialized. -The `keep` function is applied on the leaves of `x`. -If `keep(leaf)` is `false` , the leaf is replaced by `nothing`, -otherwise it is left as is. By default, all functions are excluded. +The state can be passed to [`loadmodel!`](@ref) to restore the model. # Examples +## Copy the state into another model + ```julia-repl +julia> s = Flux.state(Dense(1, 2, tanh)) +(weight = Float32[0.5058468; 1.2398405;;], bias = Float32[0.0, 0.0], σ = missing) + julia> m1 = Chain(Dense(1, 2, tanh), Dense(2, 1)); julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1)); @@ -135,18 +142,29 @@ julia> Flux.loadmodel!(m2, s); julia> m2[1].weight == m1[1].weight true ``` + +## Save and load with BSON +```julia-repl +julia> using BSON + +julia> BSON.@save "checkpoint.bson" model_state = s + +julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state]) +``` + +## Save and load with JLD2 + +```julia-repl +julia> using JLD2 + +julia> JLD2.jldsave("checkpoint.jld2", model_state = s) + +julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state")) +``` """ -function state(x; keep = _state_keep) - if Functors.isleaf(x) - return keep(x) ? x : nothing - else - return _valuemap(c -> state(c; keep), Functors.children(x)) - end -end +state(x) = Functors.fmapstructure(x -> _state_keep(x) ? x : missing, x) -_state_keep(x::Function) = false -_state_keep(x) = true +const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol} -# map for tuples, namedtuples, and dicts -_valuemap(f, x) = map(f, x) -_valuemap(f, x::Dict) = Dict(k => f(v) for (k, v) in x) +_state_keep(x::STATE_TYPES) = true +_state_keep(x) = false diff --git a/test/loading.jl b/test/loading.jl index 11ea368462..0bb230e952 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -199,7 +199,7 @@ end @test s.layers isa Tuple @test length(s.layers) == 2 @test s.layers[1].weight === m1[1].weight - @test s.layers[1].σ === nothing + @test s.layers[1].σ === missing @test s.layers[2].layers[1].weight === m1[2].layers[1].weight Flux.loadmodel!(m2, s) @@ -212,16 +212,16 @@ end s = Flux.state(m3) @test s.layers[2].active == true @test s.layers[2].p == 0.2 - @test s.layers[4] == (λ = nothing, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0], - μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, - track_stats = true, active = true, chs = 2) + @test s.layers[4].λ === missing + for k in (:β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs) + @test s.layers[4][k] === getfield(m3[4], k) + end end - @testset "keep" begin - s = Flux.state(m1, keep = x -> x isa AbstractArray) - @test s.layers[1].weight isa AbstractArray - @test s.layers[1].σ === nothing - @test s.layers[2].connection === nothing - @test s.layers[2].layers[1].bias === nothing + @testset "saved types" begin + m = (num = 1, cnum = Complex(1.2, 2), str = "hello", arr = [1, 2, 3], + dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), sym = :a, nth = nothing) + s = Flux.state(m) + @test s == m end end From e3db0e13ee0fb91321878eb95ebc088a1b291e68 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 2 May 2023 18:13:27 +0200 Subject: [PATCH 06/15] apply suggestions --- docs/checkpoint.bson | Bin 0 -> 4222 bytes docs/make.jl | 4 +++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 docs/checkpoint.bson diff --git a/docs/checkpoint.bson b/docs/checkpoint.bson new file mode 100644 index 0000000000000000000000000000000000000000..3540a73bb93624df64cc171d1a90a0992a6afcf6 GIT binary patch literal 4222 zcmeHJy=xRf6o1*>7)1-YR*;}#C31I0TS)>T1Y#w266SJ~bIINZI}_wWNF!;)bV;9D z>=Z2h6Kn(<{{_LyQvAI)JC{4}!yF`8i8sa0+kL-%^P72JvwNK=tT(GVSu<8yO|P#J z1yot5x{r70 z?|QwMZ`Ui$go1{uYgsQ*ZJioGC@j(Z$n}1OV(%v=S#Kbz_YU-BY82=Vv|x7Nnf=Lp zo9X?F*niidbP7mCcFiz4;$^87HP z-|cu_eOYrh9mo02_dguWyGhzgQ~p2h2hRN;^GREut4Ld)!Tssco3{HtJ@`I!mX#$- zTTe0L#Wm=7 Date: Tue, 2 May 2023 18:14:04 +0200 Subject: [PATCH 07/15] cleanup --- Project.toml | 13 +++++++------ docs/checkpoint.bson | Bin 4222 -> 0 bytes src/loading.jl | 32 +++++++++++++++++++------------- test/loading.jl | 4 +++- 4 files changed, 29 insertions(+), 20 deletions(-) delete mode 100644 docs/checkpoint.bson diff --git a/Project.toml b/Project.toml index 4e2dba884e..2fe32ac7da 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -24,6 +25,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + +[extensions] +AMDGPUExt = "AMDGPU" + [compat] AMDGPU = "0.4.13" Adapt = "3.0" @@ -44,9 +51,6 @@ Zygote = "0.6.49" cuDNN = "1" julia = "1.6" -[extensions] -AMDGPUExt = "AMDGPU" - [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -59,6 +63,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON"] - -[weakdeps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" diff --git a/docs/checkpoint.bson b/docs/checkpoint.bson deleted file mode 100644 index 3540a73bb93624df64cc171d1a90a0992a6afcf6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4222 zcmeHJy=xRf6o1*>7)1-YR*;}#C31I0TS)>T1Y#w266SJ~bIINZI}_wWNF!;)bV;9D z>=Z2h6Kn(<{{_LyQvAI)JC{4}!yF`8i8sa0+kL-%^P72JvwNK=tT(GVSu<8yO|P#J z1yot5x{r70 z?|QwMZ`Ui$go1{uYgsQ*ZJioGC@j(Z$n}1OV(%v=S#Kbz_YU-BY82=Vv|x7Nnf=Lp zo9X?F*niidbP7mCcFiz4;$^87HP z-|cu_eOYrh9mo02_dguWyGhzgQ~p2h2hRN;^GREut4Ld)!Tssco3{HtJ@`I!mX#$- zTTe0L#Wm=7 s = Flux.state(Dense(1, 2, tanh)) -(weight = Float32[0.5058468; 1.2398405;;], bias = Float32[0.0, 0.0], σ = missing) - -julia> m1 = Chain(Dense(1, 2, tanh), Dense(2, 1)); - -julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1)); +```jldoctest +julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones)); julia> s = Flux.state(m1) -layers = ((weight = Float32[-0.56867087; 1.229064;;], bias = Float32[0.0, 0.0], σ = nothing), (weight = Float32[0.23323897 -0.5561147], bias = Float32[0.0], σ = nothing)),) +(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = missing), (weight = [1.0 1.0], bias = [0.0], σ = missing)),) + +julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers julia> Flux.loadmodel!(m2, s); -julia> m2[1].weight == m1[1].weight -true +julia> m2[1].weight # now the weights of m2 are the same as m1 +2×1 Matrix{Float32}: + 1.0 + 1.0 + +julia> Flux.state(trainmode!(Dropout(0.2))) # contains p & activity, but not RNG state +(p = 0.2, dims = missing, active = true, rng = missing) + +julia> Flux.state(BatchNorm(1)) # contains non-trainable arrays μ, σ² +(λ = missing, β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) ``` ## Save and load with BSON + ```julia-repl julia> using BSON @@ -162,9 +168,9 @@ julia> JLD2.jldsave("checkpoint.jld2", model_state = s) julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state")) ``` """ -state(x) = Functors.fmapstructure(x -> _state_keep(x) ? x : missing, x) +state(x) = Functors.fmapstructure(_state, x) const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol} -_state_keep(x::STATE_TYPES) = true -_state_keep(x) = false +_state(x::STATE_TYPES) = x +_state(x) = missing diff --git a/test/loading.jl b/test/loading.jl index 0bb230e952..ce8e5ba757 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -220,7 +220,9 @@ end @testset "saved types" begin m = (num = 1, cnum = Complex(1.2, 2), str = "hello", arr = [1, 2, 3], - dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), sym = :a, nth = nothing) + bool = true, dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), + sym = :a, nth = nothing) + s = Flux.state(m) @test s == m end From 30cf25a0321661cf1afe8e4cb4fe00d9efe0b9b2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 2 May 2023 18:21:51 +0200 Subject: [PATCH 08/15] remove callback from docs --- docs/src/saving.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index e34fbfaae2..294d4c66d2 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -51,7 +51,7 @@ julia> Flux.loadmodel!(model, model_state); ## Checkpointing -In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md). +In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). ```jldoctest saving julia> using Flux: throttle @@ -64,12 +64,13 @@ Chain( Dense(5 => 2), # 12 parameters ) # Total: 4 arrays, 67 parameters, 524 bytes. -julia> evalcb = throttle(30) do +julia> for epoch in 1:10 + # ... train model ... jldsave("model-checkpoint.jld2", model_state = Flux.state(m)) end; ``` -This will update the `"model-checkpoint.jld2"` file every thirty seconds. +This will update the `"model-checkpoint.jld2"` every epoch. You can get more advanced by saving a series of models throughout training, for example From 532fe94c9622c1240a3a75ac1a60d542723dcad0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 2 May 2023 18:54:57 +0200 Subject: [PATCH 09/15] fix --- Project.toml | 1 - docs/Project.toml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2fe32ac7da..709e99a1d2 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/docs/Project.toml b/docs/Project.toml index 4af31f2254..a4d907e63e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" From 71492921df3eab0a0b21caaf60524687c92d97f9 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 3 May 2023 16:00:39 +0200 Subject: [PATCH 10/15] new proposal pruning missings --- docs/src/saving.md | 4 ++-- src/loading.jl | 28 ++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index 294d4c66d2..16f944ef08 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -18,7 +18,7 @@ julia> struct MyModel julia> Flux.@functor MyModel -julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))) +julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))); julia> model = MyModel() MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) @@ -35,7 +35,7 @@ Load it again in a new session using [`Flux.loadmodel!`](@ref): ```jldoctest saving julia> using Flux, JLD2 -julia> model_state = JLD2.load("mymodel.jld2", "model_state") +julia> model_state = JLD2.load("mymodel.jld2", "model_state"); julia> model = MyModel(); # MyModel definition must be available diff --git a/src/loading.jl b/src/loading.jl index c42d019d39..24f16cf234 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -90,10 +90,10 @@ but copying a `src` value of `true` will error. function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) ldsts = _filter_children(filter, Functors.children(dst)) lsrcs = _filter_children(filter, Functors.children(src)) - (keys(ldsts) == keys(lsrcs)) || - throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match.")) - - foreach(ldsts, lsrcs) do ldst, lsrc + keys_ldsts = keys(ldsts) + for k in keys(lsrcs) + k ∈ keys_ldsts || throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match.")) + lsrc, ldst = lsrcs[k], ldsts[k] if ldst in cache # we already loaded this parameter before _tie_check(ldst, lsrc) && return ldst elseif Functors.isleaf(ldst) # our first time loading this leaf @@ -130,7 +130,7 @@ The state can be passed to [`loadmodel!`](@ref) to restore the model. julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones)); julia> s = Flux.state(m1) -(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = missing), (weight = [1.0 1.0], bias = [0.0], σ = missing)),) +(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0]), (weight = [1.0 1.0], bias = [0.0])),) julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers @@ -142,10 +142,10 @@ julia> m2[1].weight # now the weights of m2 are the same as m1 1.0 julia> Flux.state(trainmode!(Dropout(0.2))) # contains p & activity, but not RNG state -(p = 0.2, dims = missing, active = true, rng = missing) +(p = 0.2, active = true) julia> Flux.state(BatchNorm(1)) # contains non-trainable arrays μ, σ² -(λ = missing, β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) +(β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) ``` ## Save and load with BSON @@ -168,9 +168,21 @@ julia> JLD2.jldsave("checkpoint.jld2", model_state = s) julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state")) ``` """ -state(x) = Functors.fmapstructure(_state, x) +state(x) = Functors.fmapstructure(_state, x) |> prune_missing const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol} _state(x::STATE_TYPES) = x _state(x) = missing + +prune_missing(x) = x + +prune_missing(nt::NamedTuple) = + (; (k => prune_missing(v) for (k,v) in pairs(nt) if !ismissing(v))...) + +prune_missing(d::Dict) = + Dict(k => prune_missing(v) for (k,v) in pairs(d) if !ismissing(v)) + +# we preserve missings in tuples to avoid ambiguities +prune_missing(t::Tuple) = prune_missing.(t) + From 6e1a5e1cf14fd6da83be36a880a428dd15dfd99c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 May 2023 06:31:17 +0200 Subject: [PATCH 11/15] sentinel is empty tuple --- NEWS.md | 1 - src/loading.jl | 5 +++-- test/loading.jl | 10 ++++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 6e9f80f73f..611fd53868 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,5 @@ # Flux Release Notes -<<<<<<< HEAD See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. ## v0.13.16 diff --git a/src/loading.jl b/src/loading.jl index 24f16cf234..1b28115040 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -183,6 +183,7 @@ prune_missing(nt::NamedTuple) = prune_missing(d::Dict) = Dict(k => prune_missing(v) for (k,v) in pairs(d) if !ismissing(v)) -# we preserve missings in tuples to avoid ambiguities -prune_missing(t::Tuple) = prune_missing.(t) +# we replace missings with () in tuples instead +# of dropping them to avoid ambiguities +prune_missing(t::Tuple) = ((ismissing(x) ? () : prune_missing(x) for x in t)...,) diff --git a/test/loading.jl b/test/loading.jl index ce8e5ba757..f659a2c7d2 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -199,20 +199,26 @@ end @test s.layers isa Tuple @test length(s.layers) == 2 @test s.layers[1].weight === m1[1].weight - @test s.layers[1].σ === missing + @test !hasfield(typeof(s.layers[1]), :σ) @test s.layers[2].layers[1].weight === m1[2].layers[1].weight Flux.loadmodel!(m2, s) @test m2[1].weight == m1[1].weight @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias) + @testset "sentinel value is empty tuple" begin + @test Flux.state((1, tanh)) == (1, ()) + @test Flux.state((a=1, b=tanh)) == (; a=1) + @test Flux.state(Dict(:a=>1, :b=>tanh)) == Dict(:a=>1) + end + @testset "track active state and batch norm params" begin m3 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2), BatchNorm(2)) trainmode!(m3) s = Flux.state(m3) @test s.layers[2].active == true @test s.layers[2].p == 0.2 - @test s.layers[4].λ === missing + @test !hasfield(typeof(s.layers[4]), :λ) for k in (:β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs) @test s.layers[4][k] === getfield(m3[4], k) end From 6895b26ace1bbc85a2271eb161c08e5b0f15f145 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 May 2023 06:33:31 +0200 Subject: [PATCH 12/15] rewording --- test/loading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/loading.jl b/test/loading.jl index f659a2c7d2..90e53942a5 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -206,7 +206,7 @@ end @test m2[1].weight == m1[1].weight @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias) - @testset "sentinel value is empty tuple" begin + @testset "non-state elements are dropped or replaced with empty tuple" begin @test Flux.state((1, tanh)) == (1, ()) @test Flux.state((a=1, b=tanh)) == (; a=1) @test Flux.state(Dict(:a=>1, :b=>tanh)) == Dict(:a=>1) From b50d1a96d9c2ae94864043468299174683cb8660 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 4 May 2023 08:03:07 +0200 Subject: [PATCH 13/15] fix loadmodel! --- src/loading.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 1b28115040..289dc1db40 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -91,11 +91,13 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) ldsts = _filter_children(filter, Functors.children(dst)) lsrcs = _filter_children(filter, Functors.children(src)) keys_ldsts = keys(ldsts) - for k in keys(lsrcs) - k ∈ keys_ldsts || throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match.")) + keys_lsrcs = keys(lsrcs) + + for k in keys_lsrcs + k ∈ keys_ldsts || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match.")) lsrc, ldst = lsrcs[k], ldsts[k] if ldst in cache # we already loaded this parameter before - _tie_check(ldst, lsrc) && return ldst + _tie_check(ldst, lsrc) elseif Functors.isleaf(ldst) # our first time loading this leaf push!(cache, ldst) loadleaf!(ldst, lsrc) From 1eb89a2232811125fc34e07fd1b8451b02731770 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 5 May 2023 15:27:11 +0200 Subject: [PATCH 14/15] don't drop fields --- src/loading.jl | 23 +++++------------------ test/loading.jl | 16 ++++++++++------ 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 289dc1db40..5096dcca21 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -132,7 +132,7 @@ The state can be passed to [`loadmodel!`](@ref) to restore the model. julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones)); julia> s = Flux.state(m1) -(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0]), (weight = [1.0 1.0], bias = [0.0])),) +(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),) julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers @@ -144,10 +144,10 @@ julia> m2[1].weight # now the weights of m2 are the same as m1 1.0 julia> Flux.state(trainmode!(Dropout(0.2))) # contains p & activity, but not RNG state -(p = 0.2, active = true) +(p = 0.2, dims = (), active = true, rng = ()) julia> Flux.state(BatchNorm(1)) # contains non-trainable arrays μ, σ² -(β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) +(λ = (), β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1) ``` ## Save and load with BSON @@ -170,22 +170,9 @@ julia> JLD2.jldsave("checkpoint.jld2", model_state = s) julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state")) ``` """ -state(x) = Functors.fmapstructure(_state, x) |> prune_missing +state(x) = Functors.fmapstructure(_state, x) const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol} _state(x::STATE_TYPES) = x -_state(x) = missing - -prune_missing(x) = x - -prune_missing(nt::NamedTuple) = - (; (k => prune_missing(v) for (k,v) in pairs(nt) if !ismissing(v))...) - -prune_missing(d::Dict) = - Dict(k => prune_missing(v) for (k,v) in pairs(d) if !ismissing(v)) - -# we replace missings with () in tuples instead -# of dropping them to avoid ambiguities -prune_missing(t::Tuple) = ((ismissing(x) ? () : prune_missing(x) for x in t)...,) - +_state(x) = () diff --git a/test/loading.jl b/test/loading.jl index 90e53942a5..06bc412d31 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -199,17 +199,21 @@ end @test s.layers isa Tuple @test length(s.layers) == 2 @test s.layers[1].weight === m1[1].weight - @test !hasfield(typeof(s.layers[1]), :σ) + @test s.layers[1].σ === () @test s.layers[2].layers[1].weight === m1[2].layers[1].weight Flux.loadmodel!(m2, s) @test m2[1].weight == m1[1].weight @test all(m2[2].layers[1].bias .== m1[2].layers[1].bias) - @testset "non-state elements are dropped or replaced with empty tuple" begin + @testset "non-state elements are replaced with empty tuple" begin @test Flux.state((1, tanh)) == (1, ()) - @test Flux.state((a=1, b=tanh)) == (; a=1) - @test Flux.state(Dict(:a=>1, :b=>tanh)) == Dict(:a=>1) + @test Flux.state((a=1, b=tanh)) == (; a=1, b=()) + @test Flux.state(Dict(:a=>1, :b=>tanh)) == Dict(:a=>1, :b=>()) + X, Y = Flux.ones32(3, 2), Flux.zeros32(2, 2) + tree = Dict(:a=>1, :b=>(; c=X, d=(Y, 1, (tanh,)), e=sin)) + state_tree = Dict(:a=>1, :b=>(; c=X, d=(Y, 1, ((),)), e=())) + @test Flux.state(tree) == state_tree end @testset "track active state and batch norm params" begin @@ -218,13 +222,13 @@ end s = Flux.state(m3) @test s.layers[2].active == true @test s.layers[2].p == 0.2 - @test !hasfield(typeof(s.layers[4]), :λ) + @test s.layers[4].λ === () for k in (:β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs) @test s.layers[4][k] === getfield(m3[4], k) end end - @testset "saved types" begin + @testset "preservation of saved types" begin m = (num = 1, cnum = Complex(1.2, 2), str = "hello", arr = [1, 2, 3], bool = true, dict = Dict(:a => 1, :b => 2), tup = (1, 2, 3), sym = :a, nth = nothing) From 8d45e774c7d1fce6cac5e02c9103308b52023544 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 5 May 2023 16:29:29 +0200 Subject: [PATCH 15/15] require keys to be equal in loadmodel! --- src/loading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loading.jl b/src/loading.jl index 5096dcca21..8238a19cde 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -92,9 +92,9 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) lsrcs = _filter_children(filter, Functors.children(src)) keys_ldsts = keys(ldsts) keys_lsrcs = keys(lsrcs) + collect(keys_ldsts) == collect(keys_lsrcs) || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match.")) for k in keys_lsrcs - k ∈ keys_ldsts || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match.")) lsrc, ldst = lsrcs[k], ldsts[k] if ldst in cache # we already loaded this parameter before _tie_check(ldst, lsrc)