Skip to content

Commit d1d29da

Browse files
Merge remote-tracking branch 'origin/master' into dg/grad
2 parents 10ec1ae + adeb24b commit d1d29da

File tree

6 files changed

+146
-23
lines changed

6 files changed

+146
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Functors"
22
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
33
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,22 @@ Any field not in the list will not be returned by `functor` and passed through a
7171
It is also possible to implement `functor` by hand when greater flexibility is required. See [here](https://github.com/FluxML/Functors.jl/issues/3) for an example.
7272

7373
For a discussion regarding the need for a `cache` in the implementation of `fmap`, see [here](https://github.com/FluxML/Functors.jl/issues/2).
74+
75+
Use `exclude` for more fine-grained control over whether `fmap` descends into a particular value (the default is `exclude = Functors.isleaf`):
76+
```julia
77+
julia> using CUDA
78+
79+
julia> x = ['a', 'b', 'c'];
80+
81+
julia> fmap(cu, x)
82+
3-element Array{Char,1}:
83+
'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
84+
'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
85+
'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
86+
87+
julia> fmap(cu, x; exclude = x -> CUDA.isbitstype(eltype(x)))
88+
3-element CuArray{Char,1}:
89+
'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
90+
'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
91+
'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
92+
```

src/Functors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Functors
22

33
using MacroTools
44

5-
export @functor, fmap
5+
export @functor, fmap, fcollect
66

77
include("functor.jl")
88

src/functor.jl

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,20 @@ macro functor(args...)
2929
functorm(args...)
3030
end
3131

32-
isleaf(x) = functor(x)[1] === ()
32+
"""
33+
isleaf(x)
34+
35+
Return true if `x` has no [`children`](@ref) according to [`functor`](@ref).
36+
"""
37+
isleaf(x) = children(x) === ()
38+
39+
"""
40+
children(x)
41+
42+
Return the children of `x` as defined by [`functor`](@ref).
43+
Equivalent to `functor(x)[1]`.
44+
"""
45+
children(x) = functor(x)[1]
3346

3447
# for Chain
3548
function functor_tuple(f, x::Tuple, dx::Tuple)
@@ -56,9 +69,67 @@ end
5669

5770
# See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding the need for
5871
# cache.
59-
function fmap(f, x; cache = IdDict())
72+
function fmap(f, x; exclude = isleaf, cache = IdDict())
6073
haskey(cache, x) && return cache[x]
61-
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
74+
y = exclude(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache, exclude = exclude), x)
75+
cache[x] = y
76+
77+
return y
78+
end
79+
80+
"""
81+
fcollect(x; exclude = v -> false)
82+
83+
Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
84+
and collecting the results into a flat array.
85+
86+
Doesn't recurse inside branches rooted at nodes `v`
87+
for which `exclude(v) == true`.
88+
In such cases, the root `v` is also excluded from the result.
89+
By default, `exclude` always yields `false`.
90+
91+
See also [`children`](@ref).
92+
93+
# Examples
94+
95+
```jldoctest
96+
julia> struct Foo; x; y; end
97+
98+
julia> @functor Foo
99+
100+
julia> struct Bar; x; end
101+
102+
julia> @functor Bar
103+
104+
julia> struct NoChildren; x; y; end
105+
106+
julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b))
107+
108+
julia> fcollect(m)
109+
4-element Vector{Any}:
110+
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
111+
Bar([1, 2, 3])
112+
[1, 2, 3]
113+
NoChildren(:a, :b)
114+
115+
julia> fcollect(m, exclude = v -> v isa Bar)
116+
2-element Vector{Any}:
117+
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
118+
NoChildren(:a, :b)
119+
120+
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
121+
2-element Vector{Any}:
122+
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
123+
Bar([1, 2, 3])
124+
```
125+
"""
126+
function fcollect(x; cache = [], exclude = v -> false)
127+
x in cache && return cache
128+
if !exclude(x)
129+
push!(cache, x)
130+
foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x))
131+
end
132+
return cache
62133
end
63134

64135
# Allow gradients and other constructs that match the structure of the functor

test/basics.jl

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
11
using Functors, Test
22

3-
@testset "Nested" begin
4-
struct Foo
5-
x
6-
y
7-
end
3+
struct Foo
4+
x
5+
y
6+
end
7+
@functor Foo
88

9-
@functor Foo
9+
struct Bar
10+
x
11+
end
12+
@functor Bar
1013

11-
struct Bar
12-
x
13-
end
14+
struct Baz
15+
x
16+
y
17+
z
18+
end
19+
@functor Baz (y,)
1420

15-
@functor Bar
21+
struct NoChildren
22+
x
23+
y
24+
end
1625

26+
@testset "Nested" begin
1727
model = Bar(Foo(1, [1, 2, 3]))
1828

1929
model′ = fmap(float, model)
@@ -22,17 +32,40 @@ using Functors, Test
2232
@test model′.x.y isa Vector{Float64}
2333
end
2434

35+
@testset "Exclude" begin
36+
f(x::AbstractArray) = x
37+
f(x::Char) = 'z'
38+
39+
x = ['a', 'b', 'c']
40+
@test fmap(f, x) == ['z', 'z', 'z']
41+
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
42+
43+
x = (['a', 'b', 'c'], ['d', 'e', 'f'])
44+
@test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z'])
45+
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
46+
end
47+
2548
@testset "Property list" begin
26-
struct Baz
27-
x
28-
y
29-
z
30-
end
31-
32-
@functor Baz (y,)
33-
3449
model = Baz(1, 2, 3)
3550
model′ = fmap(x -> 2x, model)
3651

3752
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
3853
end
54+
55+
@testset "fcollect" begin
56+
m1 = [1, 2, 3]
57+
m2 = 1
58+
m3 = Foo(m1, m2)
59+
m4 = Bar(m3)
60+
@test all(fcollect(m4) .=== [m4, m3, m1, m2])
61+
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2])
62+
@test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4])
63+
64+
m1 = [1, 2, 3]
65+
m2 = Bar(m1)
66+
m0 = NoChildren(:a, :b)
67+
m3 = Foo(m2, m0)
68+
m4 = Bar(m3)
69+
println(fcollect(m4))
70+
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
71+
end

0 commit comments

Comments
 (0)