Skip to content

Commit ea4a7ea

Browse files
Merge pull request #10 from darsnack/darsnack/predicate-fmap
Add predicate keyword to fmap
2 parents 0b90595 + 1fd68e2 commit ea4a7ea

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,22 @@ Any field not in the list will not be returned by `functor` and passed through a
7171
It is also possible to implement `functor` by hand when greater flexibility is required. See [here](https://github.com/FluxML/Functors.jl/issues/3) for an example.
7272

7373
For a discussion regarding the need for a `cache` in the implementation of `fmap`, see [here](https://github.com/FluxML/Functors.jl/issues/2).
74+
75+
Use `exclude` for more fine-grained control over whether `fmap` descends into a particular value (the default is `exclude = Functors.isleaf`):
76+
```julia
77+
julia> using CUDA
78+
79+
julia> x = ['a', 'b', 'c'];
80+
81+
julia> fmap(cu, x)
82+
3-element Array{Char,1}:
83+
'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
84+
'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
85+
'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
86+
87+
julia> fmap(cu, x; exclude = x -> CUDA.isbitstype(eltype(x)))
88+
3-element CuArray{Char,1}:
89+
'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
90+
'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
91+
'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
92+
```

src/functor.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@ end
5151

5252
# See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding the need for
5353
# cache.
54-
function fmap(f, x; cache = IdDict())
54+
function fmap(f, x; exclude = isleaf, cache = IdDict())
5555
haskey(cache, x) && return cache[x]
56-
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
56+
y = exclude(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache, exclude = exclude), x)
57+
cache[x] = y
58+
59+
return y
5760
end
5861

5962
"""

test/basics.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,20 @@ end
3232
@test model′.x.y isa Vector{Float64}
3333
end
3434

35-
@testset "Property list" begin
35+
@testset "Exclude" begin
36+
f(x::AbstractArray) = x
37+
f(x::Char) = 'z'
38+
39+
x = ['a', 'b', 'c']
40+
@test fmap(f, x) == ['z', 'z', 'z']
41+
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
42+
43+
x = (['a', 'b', 'c'], ['d', 'e', 'f'])
44+
@test fmap(f, x) == (['z', 'z', 'z'], ['z', 'z', 'z'])
45+
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
46+
end
47+
48+
@testset "Property list" begin
3649
model = Baz(1, 2, 3)
3750
model′ = fmap(x -> 2x, model)
3851

0 commit comments

Comments
 (0)