diff --git a/.gitignore b/.gitignore
index 20fe29d..35c98b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@
 *.jl.mem
 /Manifest.toml
 /docs/build/
+.vscode/
diff --git a/src/StructWalk.jl b/src/StructWalk.jl
index 17575cf..a74c77d 100644
--- a/src/StructWalk.jl
+++ b/src/StructWalk.jl
@@ -6,7 +6,7 @@ using ConstructionBase: constructorof
 export prewalk, postwalk, mapleaves
 
 """
-Abstract type `WalkStyle`
+    abstract type WalkStyle
 
 and Subtype `WalkStyle` and overload [`walkstyle`](@ref) to define custom walking behaviors (constructor / children /...).
 """
@@ -18,12 +18,12 @@ abstract type WalkStyle end
 Should return a tuple of length 3 with:
 
   1. [constructor](@ref): A proper constuctor for `T`, can be `identity` if `x` isa leaf.
-  2. [children](@ref): Children of `x` in a tuple, or empty tuple `()` if `x` is a leaf.
-  3. [iscontainer](@ref): A bool indicate whether element of 2. is the actual list of children. default to `false`.
-
-For example, since `Array` has 0 `fieldcount`, we doesn't split the value into a tuple as children.
- Instead, we return `(x,)` as children and the extra boolean `true`, so it will `walk`/`map` through `x`
- accordingly.
+  2. [children](@ref): Children of `x` in a tuple, or empty tuple `()` if `x` is a leaf.
+     Named tuples are also allowed as alternatives to tuples.
+  3. [iscontainer](@ref): A `Bool` indicating whether the element of 2. is the actual list of children.
+     For example, since `Array` has 0 `fieldcount`, we don't split the value into a tuple as children.
+     Instead, we return `(x,)` as children and the extra boolean `true`, so it will `walk`/`map` through `x`
+     accordingly. Default `false`.
 """
 function walkstyle end
 
@@ -77,7 +77,7 @@ iscontainer(::Type{WalkStyle}, x) = false
 const WALKSTYLE = Union{WalkStyle, Type{<:WalkStyle}}
 
 # default walkstyle for some types
-include("./walkstyle.jl")
+include("walkstyle.jl")
 
 """
     LeafNode(x)
@@ -222,8 +222,9 @@ x = (a = 2, b = (c = 4, d = 0))
 mapnonleaves(f, x) = mapnonleaves(f, WalkStyle, x)
 mapnonleaves(f, style::WALKSTYLE, x) = walk(identity, f, style, x -> mapnonleaves(f, style, x), x)
 
-include("./aligned.jl")
-include("./scan.jl")
+include("aligned.jl")
+include("scan.jl")
+include("functors.jl")
 
 @specialize
 
diff --git a/src/functors.jl b/src/functors.jl
new file mode 100644
index 0000000..a361736
--- /dev/null
+++ b/src/functors.jl
@@ -0,0 +1,118 @@
+# Replacement for Functors.jl
+
+# Children of a leaf: `NoChildren()` constructs the empty tuple `()`.
+const NoChildren = Tuple{}
+
+struct FunctorStyle <: WalkStyle end
+
+isleaf(@nospecialize(x)) = children(FunctorStyle(), x) === NoChildren()
+
+# Numeric arrays are leaves: no children, and the constructor hands back `x` itself.
+children(::FunctorStyle, x::AbstractArray{<:Number}) = ()
+constructor(::FunctorStyle, x::AbstractArray{<:Number}) = _ -> x
+iscontainer(::FunctorStyle, x::AbstractArray{<:Number}) = false
+
+# Standard containers are their own children; `identity` keeps the mapped container.
+constructor(::FunctorStyle, x::AbstractArray) = identity
+constructor(::FunctorStyle, x::Tuple) = identity
+constructor(::FunctorStyle, x::NamedTuple) = identity
+constructor(::FunctorStyle, x::Dict) = identity
+
+children(::FunctorStyle, x::AbstractArray) = x
+children(::FunctorStyle, x::Tuple) = x
+children(::FunctorStyle, x::NamedTuple) = x
+children(::FunctorStyle, x::Dict) = x
+
+# Any other type: field-less types use `identity`; otherwise the mapped children are splatted
+# into `constructorof(T)`, which is in scope via `using ConstructionBase: constructorof`.
+function constructor(::FunctorStyle, x::T) where T
+    iszero(fieldcount(T)) && return identity
+    return ch -> constructorof(T)(ch...)
+end
+
+# mimics Functors.fmap
+fmap(f, x) = functor_mapleaves(f, FunctorStyle(), x)
+
+# mimics Functors.fmapstructure
+struct FunctorStructureStyle <: WalkStyle end
+children(::FunctorStructureStyle, x) = children(FunctorStyle(), x)
+iscontainer(::FunctorStructureStyle, x) = iscontainer(FunctorStyle(), x)
+constructor(::FunctorStructureStyle, x) = to_standard_container
+
+# Standard containers pass through unchanged; any other value becomes a NamedTuple of its fields.
+to_standard_container(x::Union{Tuple, NamedTuple, AbstractArray, AbstractDict}) = x
+to_standard_container(x::T) where T = (; (f => getfield(x, f) for f in fieldnames(T))...)
+
+"""
+    fmapstructure(f, x)
+
+Like `fmap`, but doesn't preserve the type of custom structs. Instead, it returns a NamedTuple (or a Tuple, or an array),
+or a nested set of these.
+
+Useful for when the output must not contain custom structs.
+
+# Examples
+```
+julia> struct Foo; x; y; end
+
+julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]);
+
+julia> fmapstructure(x -> 2x, m)
+(x = [2, 4, 6], y = Any[8, (10, 12), (x = 14, y = 16)])
+
+julia> fmapstructure(println, m)
+[1, 2, 3]
+4
+5
+6
+7
+8
+(x = nothing, y = Any[nothing, (nothing, nothing), (x = nothing, y = nothing)])
+```
+"""
+fmapstructure(f, x) = functor_mapleaves(f, FunctorStructureStyle(), x)
+
+
+functor_mapleaves(f, style::WALKSTYLE, x) = functor_walk(f, identity, style, x -> functor_mapleaves(f, style, x), x)
+
+### Same as walk but doesn't splat the constructor.
+### We could replace `walk` with this in the next breaking release.
+# Apply one level of the walk to `x`. When `walkstyle` reports no children, `x` is a leaf
+# and `f(x)` is returned. Otherwise every child is mapped with `inner_walk`, the mapped
+# collection is handed UNSPLATTED to the constructor, and `g` post-processes the result.
+function functor_walk(f, g, style::WALKSTYLE, inner_walk, x)
+    rebuild, fields, _ = walkstyle(style, x)  # the third element (iscontainer) is not used
+    isempty(fields) && return f(x)
+    return g(rebuild(mapvalues(inner_walk, fields)))
+end
+
+# Map `f` over a collection of children. The generic method defers to `map`; the `Dict`
+# method maps only the values and keeps every key.
+mapvalues(f, x) = map(f, x)
+mapvalues(f, x::Dict) = Dict(k => f(v) for (k, v) in pairs(x))
+
+# functor(::Type{<:Adjoint}, x) = (parent = _adjoint(x),), y -> adjoint(only(y))
+
+# _adjoint(x) = adjoint(x) # _adjoint is the inverse, and also understands more types:
+# _adjoint(x::NamedTuple{(:parent,)}) = x.parent # "structural" gradient, and lazy broadcast used by Optimisers:
+# _adjoint(bc::Broadcast.Broadcasted{S}) where S = Broadcast.Broadcasted{S}(_conjugate(bc.f, adjoint), _adjoint.(bc.args))
+
+# functor(::Type{<:Transpose}, x) = (parent = _transpose(x),), y -> transpose(only(y))
+
+# _transpose(x) = transpose(x)
+# _transpose(x::NamedTuple{(:parent,)}) = x.parent
+# _transpose(bc::Broadcast.Broadcasted{S}) where S = Broadcast.Broadcasted{S}(_conjugate(bc.f, transpose), _transpose.(bc.args))
+
+# _conjugate(f::F, ::typeof(identity)) where F = f
+# _conjugate(f::F, op::Union{typeof(transpose), typeof(adjoint)}) where F = (xs...,) -> op(f(op.(xs)...))
+
+# function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x) where {T,N,perm,iperm}
+#     (parent = _PermutedDimsArray(x, iperm),), y -> PermutedDimsArray(only(y), perm)
+# end
+# function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x::PermutedDimsArray{Tx,N,perm,iperm}) where {T,Tx,N,perm,iperm}
+#     (parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping twice.
+# end
+
+# _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm)
+# _PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent
+# _PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _Pe
\ No newline at end of file
diff --git a/test/functors.jl b/test/functors.jl
new file mode 100644
index 0000000..b0dba3d
--- /dev/null
+++ b/test/functors.jl
@@ -0,0 +1,387 @@
+using StructWalk: FunctorStyle, fmap, isleaf, NoChildren, children, fmapstructure
+using StructWalk
+using Test
+
+const FS = FunctorStyle()
+
+# Plain structs: no `FunctorStyle` methods are overloaded for them in this file.
+struct Foo; x; y; end
+struct Bar{T}; x::T; end
+
+# Only the middle field `y` is exposed as a child; `x` and `z` are carried over unchanged.
+struct OneChild3; x; y; z; end
+StructWalk.children(::FunctorStyle, node::OneChild3) = (node.y,)
+StructWalk.constructor(::FunctorStyle, node::OneChild3) = ch -> OneChild3(node.x, ch..., node.z)
+
+# Leaf types: they expose no children and their constructor returns the value itself.
+struct NoChildren2; x; y; end
+StructWalk.children(::FunctorStyle, leaf::NoChildren2) = ()
+StructWalk.constructor(::FunctorStyle, leaf::NoChildren2) = _ -> leaf
+
+struct NoChild{T}; x::T; end
+StructWalk.children(::FunctorStyle, leaf::NoChild) = ()
+StructWalk.constructor(::FunctorStyle, leaf::NoChild) = _ -> leaf
+
+
+# ###
+# ### Basic functionality
+# ###
+
+@testset "Children and Leaves" begin
+    leaf, branch = NoChildren2(1, 2), Foo(1, 2)
+    @test isleaf(leaf)
+    @test children(FS, leaf) === NoChildren()
+    # a struct without overloads is expected to expose its fields as a NamedTuple
+    @test !isleaf(branch)
+    @test children(FS, branch) == (x=1, y=2)
+end
+
+@testset "Nested" begin
+    original = Bar(Foo(1, [1, 2, 3]))
+    floated = fmap(float, original)
+    inner = floated.x.y
+
+    @test original.x.y == inner
+    @test inner isa Vector{Float64}
+end
+
+# @testset "Exclude" begin
+#     f(x::AbstractArray) = x
+#     f(x::Char) = 'z'
+
+#     x = ['a', 'b', 'c']
+#     @test fmap(f, x) == ['z', 'z', 'z']
+#     @test fmap(f, x; exclude = x -> x isa AbstractArray) == x
+
+#     x = (['a', 'b', 'c'], ['d', 'e', 'f'])
+#     @test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z'])
+#     @test fmap(f, x; exclude = x -> x isa AbstractArray) == x
+# end
+
+@testset "Property list" begin
+    doubled = fmap(x -> 2x, OneChild3(1, 2, 3))
+    # only `y` is a child of OneChild3, so only `y` is doubled
+    @test (doubled.x, doubled.y, doubled.z) == (1, 4, 3)
+end
+
+# @testset "Sharing" begin
+#     shared = [1,2,3]
+#     m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
+#     m1f = fmap(float, m1)
+#     @test m1f.x === m1f.y.y.x
+#     @test m1f.x !== m1f.y.x
+#     m1p = fmapstructure(identity, m1; prune = nothing)
+#     @test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))
+#     m1no = fmap(float, m1; cache = nothing) # disable the cache by hand
+#     @test m1no.x !== m1no.y.y.x
+
+#     # Here "4" is not shared, because Foo isn't leaf:
+#     m2 = Foo(Foo(shared, 4), Foo(shared, 4))
+#     @test m2.x === m2.y
+#     m2f = fmap(float, m2)
+#     @test m2f.x.x === m2f.y.x
+#     m2p = fmapstructure(identity, m2; prune = Bar(0))
+#     @test m2p == (x = (x = [1, 2, 3], y = 4), y = (x = Bar{Int64}(0), y = 4))
+
+#     # Repeated isbits types should not automatically be regarded as shared:
+#     m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
+#     m3p = fmapstructure(identity, m3; prune = 0)
+#     @test m3p.y.y == 0
+#     @test m3p.y.x == 1:3
+
+#     # All-isbits trees need not create a cache at all:
+#     m4 = (x=1, y=(2, 3), z=4:5)
+#     @test isbits(fmap(float, m4))
+#     @test_skip 0 == @allocated fmap(float, m4) # true, but fails in tests
+
+#     # Shared mutable containers are preserved, even if all children are isbits:
+#     ref = Ref(1)
+#     m5 = (x = ref, y = ref, z = Ref(1))
+#     m5f = fmap(x -> x/2, m5)
+#     @test m5f.x === m5f.y
+#     @test m5f.x !== m5f.z
+
+#     @testset "usecache ($d)" for d in [IdDict(), Base.IdSet()]
+#         # Leaf types:
+#         @test usecache(d, [1,2])
+#         @test !usecache(d, 4.0)
+#         @test usecache(d, NoChild([1,2]))
+#         @test !usecache(d, NoChild((3,4)))
+
+#         # Not leaf:
+#         @test usecache(d, Ref(3)) # mutable container
+#         @test !usecache(d, (5, 6.0))
+#         @test !usecache(d, (a = 2pi, b = missing))
+
+#         @test !usecache(d, (5, [6.0]')) # contains mutable
+#         @test !usecache(d, (x = [1,2,3], y = 4))
+
+#         usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care?
+ +# # No dictionary: +# @test !usecache(nothing, [1,2]) +# @test !usecache(nothing, 3) +# end +# end + +# @testset "functor(typeof(x), y) from @functor" begin +# nt1, re1 = functor(Foo, (x=1, y=2, z=3)) +# @test nt1 == (x = 1, y = 2) +# @test re1((x = 10, y = 20)) == Foo(10, 20) +# re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug? + +# nt2, re2 = functor(Foo, (z=33, x=1, y=2)) +# @test nt2 == (x = 1, y = 2) +# @test re2((x = 10, y = 20)) == Foo(10, 20) + +# @test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y + +# nt3, re3 = functor(OneChild3, (x=1, y=2, z=3)) +# @test nt3 == (y = 2,) +# @test re3((y = 20,)) == OneChild3(1, 20, 3) +# re3(22) # gives OneChild3(1, 22, 3), is that a bug? +# end + +# @testset "functor(typeof(x), y) for Base types" begin +# nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3)) +# @test nt11 == (x = 1, y = 2) +# @test re11((x = 10, y = 20)) == (x = 10, y = 20) +# re11((y = 22, x = 11)) +# re11((11, 22)) # passes right through + +# nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2)) +# @test nt12 == (x = 1, y = 2) +# @test re12((x = 10, y = 20)) == (x = 10, y = 20) + +# @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) +# end + +# ### +# ### Extras +# ### + +@testset "Walk" begin + model = Foo((0, Bar([1, 2, 3])), [4, 5]) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +end + +# @testset "fcollect" begin +# m1 = [1, 2, 3] +# m2 = 1 +# m3 = Foo(m1, m2) +# m4 = Bar(m3) +# @test all(fcollect(m4) .=== [m4, m3, m1, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4]) + +# m1 = [1, 2, 3] +# m2 = Bar(m1) +# m0 = NoChildren2(:a, :b) +# m3 = Foo(m2, m0) +# m4 = Bar(m3) +# @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) + +# m1 = [1, 2, 3] +# m2 = [1, 2, 3] +# m3 = Foo(m1, m2) +# @test all(fcollect(m3) .=== [m3, m1, m2]) + +# m1 = [1, 2, 3] 
+# m2 = SVector{length(m1)}(m1) +# m2′ = SVector{length(m1)}(m1) +# m3 = Foo(m1, m1) +# m4 = Foo(m2, m2′) +# @test all(fcollect(m3) .=== [m3, m1]) +# @test all(fcollect(m4) .=== [m4, m2, m2′]) +# end + +# ### +# ### Vararg forms +# ### + +# @testset "fmap(f, x, y)" begin +# m1 = (x = [1,2], y = 3) +# n1 = (x = [4,5], y = 6) +# @test fmap(+, m1, n1) == (x = [5, 7], y = 9) + +# # Reconstruction type comes from the first argument +# foo1 = Foo([7,8], 9) +# @test fmap(+, m1, foo1) == (x = [8, 10], y = 12) +# @test fmap(+, foo1, n1) isa Foo +# @test fmap(+, foo1, n1).x == [11, 13] + +# # Mismatched trees should be an error +# m2 = (x = [1,2], y = (a = [3,4], b = 5)) +# n2 = (x = [6,7], y = 8) +# @test_throws Exception fmap(first∘tuple, m2, n2) # ERROR: type Int64 has no field a +# @test_throws Exception fmap(first∘tuple, m2, n2) + +# # The cache uses IDs from the first argument +# shared = [1,2,3] +# m3 = (x = shared, y = [4,5,6], z = shared) +# n3 = (x = shared, y = shared, z = [7,8,9]) +# @test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6]) +# z3 = fmap(+, m3, n3) +# @test z3.x === z3.z + +# # Pruning of duplicates: +# @test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing) + +# # More than two arguments: +# z4 = fmap(+, m3, n3, m3, n3) +# @test z4 == fmap(x -> 2x, z3) +# @test z4.x === z4.z + +# @test fmap(+, foo1, m1, n1) isa Foo +# @static if VERSION >= v"1.6" # fails on Julia 1.0 +# @test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9) +# end +# end + +# @testset "old test update.jl" begin +# struct M{F,T,S} +# σ::F +# W::T +# b::S +# end + +# @functor M + +# (m::M)(x) = m.σ.(m.W * x .+ m.b) + +# m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3)) +# x = ones(Float32, 4, 2) +# m̄, _ = gradient((m,x) -> sum(m(x)), m, x) +# m̂ = Functors.fmap(m, m̄) do x, y +# isnothing(x) && return y +# isnothing(y) && return x +# x .- 0.1f0 .* y +# end + +# @test m̂.W ≈ fill(0.8f0, size(m.W)) +# @test m̂.b ≈ fill(-0.2f0, 
size(m.b)) +# end + +# ### +# ### FlexibleFunctors.jl +# ### + +# struct FFoo +# x +# y +# p +# end +# @flexiblefunctor FFoo p + +# struct FBar +# x +# p +# end +# @flexiblefunctor FBar p + +# struct FOneChild4 +# x +# y +# z +# p +# end +# @flexiblefunctor FOneChild4 p + +# @testset "Flexible Nested" begin +# model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) + +# model′ = fmap(float, model) + +# @test model.x.y == model′.x.y +# @test model′.x.y isa Vector{Float64} +# end + +# @testset "Flexible Walk" begin +# model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) + +# model′ = fmapstructure(identity, model) +# @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) + +# model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) + +# model2′ = fmapstructure(identity, model2) +# @test model2′ == (; x=(0, (; x=[1, 2, 3]))) +# end + +# @testset "Flexible Property list" begin +# model = FOneChild4(1, 2, 3, (:x, :z)) +# model′ = fmap(x -> 2x, model) + +# @test (model′.x, model′.y, model′.z) == (2, 2, 6) +# end + +# @testset "Flexible fcollect" begin +# m1 = 1 +# m2 = [1, 2, 3] +# m3 = FFoo(m1, m2, (:y, )) +# m4 = FBar(m3, (:x,)) +# @test all(fcollect(m4) .=== [m4, m3, m2]) +# @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) +# @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) + +# m0 = NoChildren2(:a, :b) +# m1 = [1, 2, 3] +# m2 = FBar(m1, ()) +# m3 = FFoo(m2, m0, (:x, :y,)) +# m4 = FBar(m3, (:x,)) +# @test all(fcollect(m4) .=== [m4, m3, m2, m0]) +# end + +@testset "Dict" begin + d = Dict(:a => 1, :b => 2) + + @test children(FS, d) == d + @test fmap(x -> x + 1, d) == Dict(:a => 2, :b => 3) + + d = Dict(:a => 1, :b => Dict("a" => 5, "b" => 6, "c" => 7)) + @test children(FS, d) == d + @test fmap(x -> x + 1, d) == Dict(:a => 2, :b => Dict("a" => 6, "b" => 7, "c" => 8)) + +# @testset "fmap(+, x, y)" begin +# m1 = Dict("x" => [1,2], "y" => 3) +# n1 = Dict("x" => [4,5], "y" => 6) +# @test fmap(+, m1, n1) == Dict("x" => [5, 7], "y" => 9) + +# m1 = 
Dict(:x => [1,2], :y => 3) +# n1 = (x = [4,5], y = 6) +# @test fmap(+, m1, n1) == Dict(:x => [5, 7], :y => 9) + +# # extra keys in n1 are ignored +# m1 = Dict("x" => [1,2], "y" => Dict(:a => 3, :b => 4)) +# n1 = Dict("x" => [4,5], "y" => Dict(:a => 0.1, :b => 0.2, :c => 5), "z" => Dict(:a => 5)) +# @test fmap(+, m1, n1) == Dict("x" => [5, 7], "y" => Dict(:a=>3.1, :b=>4.2)) +# end +end + +# @testset "@leaf" begin +# struct A; x; end +# @functor A +# a = A(1) +# @test Functors.children(a) === (x = 1,) +# Functors.@leaf A +# children, re = Functors.functor(a) +# @test children == Functors.NoChildren() +# @test re(children) === a +# end + +# @testset "IterateWalk" begin +# x = ([1, 2, 3], 4, (5, 6, [7, 8])); +# make_iterator(x) = x isa AbstractVector ? x.^2 : (x^2,); +# iter = fmap(make_iterator, x; walk=Functors.IterateWalk(), cache=nothing); +# @test iter isa Iterators.Flatten +# @test collect(iter) == [1, 2, 3, 4, 5, 6, 7, 8].^2 + +# # Test iteration of multiple trees together +# y = ([8, 7, 6], 5, (4, 3, [2, 1])); +# make_zipped_iterator(x, y) = zip(make_iterator(x), make_iterator(y)); +# zipped_iter = fmap(make_zipped_iterator, x, y; walk=Functors.IterateWalk(), cache=nothing); +# @test zipped_iter isa Iterators.Flatten +# @test collect(zipped_iter) == collect(Iterators.zip([1, 2, 3, 4, 5, 6, 7, 8].^2, [8, 7, 6, 5, 4, 3, 2, 1].^2)) +# end \ No newline at end of file