1+
2+ """
3+ @layer Dense
4+ @layer :expand Chain
5+ @layer BatchNorm trainable=(β,γ)
6+ @layer Struct functor=(α,β) trainable=(β,)
7+
8+ This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
9+ When you define a new layer, this tells Flux to explore inside it
10+ to see the parameters it trains, and also to move them to the GPU, change precision, etc.
11+
12+ Some "keywords" allow control of the recursion:
13+ * If some fields look like parameters but should not be trained,
14+ then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
15+ * We can likewise add restructions to `Functors.functor`, but not yet written.
16+ * In fact you can provide an arbitrary keyword with this syntax, and it will
17+ overload this function alla `trainable`... that might be a terrible idea.
18+
19+ It also handles overloads of `show` for pretty printing.
20+ * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
21+ * If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
22+ * To disable all `show` overloads, maybe we want a `:ignore` option too.
23+
24+ (You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
25+ """
26+ macro layer (exs... )
27+ out = quote end
28+
29+ # These functions are defined in show.jl, and each return an expression overloading Base.show
30+ type, rest... = if exs[1 ] == QuoteNode (:expand )
31+ push! (out. args, _macro_big_show (esc (exs[2 ])))
32+ exs[2 : end ]
33+ elseif exs[1 ] == QuoteNode (:ignore )
34+ exs[2 : end ]
35+ elseif exs[1 ] isa QuoteNode
36+ error (" before the type, only accepted options are `:expand` and `:ignore`" )
37+ else
38+ push! (out. args, _macro_layer_show (esc (exs[1 ])))
39+ exs
40+ end
41+
42+ # This function exists only for depwarns when you use @functor directly
43+ push! (out. args, :(Flux. _check_new_macro (:: $ (esc (type))) = nothing )) # scope is weird ?? can't use $ on func name?
44+
45+ i = findfirst (ex -> Meta. isexpr (ex, :(= )) && ex. args[1 ] == :functor , rest)
46+ if isnothing (i)
47+ push! (out. args, _macro_functor (esc (type)))
48+ else
49+ push! (out. args, _macro_functor (esc (type), rest[i]. args[2 ]))
50+ end
51+ for j in 1 : length (rest)
52+ j == i && continue
53+ ex = rest[j]
54+ Meta. isexpr (ex, :(= )) || error (" expected keyword = fields" )
55+ if ex. args[1 ] == :trainable
56+ push! (out. args, _macro_trainable (type, trainable, ex. args[2 ])) # pass the function "trainable" not the symbol
57+ else
58+ error ()
59+ # @warn "defining a method for $(ex.args[1]) in your scope" # ??
60+ # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
61+ end
62+ end
63+
64+ out
65+ end
66+
67+ # Temporary depwarn function:
68+
69+ function _check_new_macro (x:: T ) where T
70+ Functors. isleaf (x) && return
71+ @warn " you used @functor for this type, but should now use @layer" T maxlog= 1 _id= hash (T)
72+ end
73+ _check_new_macro (:: Tuple ) = nothing # defined by Functors.jl, not by users
74+ _check_new_macro (:: NamedTuple ) = nothing
75+ _check_new_macro (:: Transpose ) = nothing
76+ _check_new_macro (:: Adjoint ) = nothing
77+ _check_new_macro (:: Ref ) = nothing
78+
79+ # @layer's code for Functors & Adapt
80+ # Unlike @functor, _default_functor doesn't need to eval anything
81+
82+ function _macro_functor (type)
83+ quote
84+ Functors. functor (:: Type{T} , x) where {T<: $type } = _default_functor (T, x)
85+ Adapt. adapt_structure (to, layer:: $type ) = fmap (adapt (to), layer)
86+ end
87+ end
88+
89+ function _macro_functor (type, fields)
90+ error (" the equivalent of @functor Layer (:x,) isn't written yet, sorry" )
91+ end
92+
93+ function _default_functor (:: Type{T} , x) where {T}
94+ if @generated
95+ F = fieldnames (T)
96+ args = map (sy -> :(getfield (x, $ (QuoteNode (sy)))), F)
97+ C = Base. typename (T). name # constructor
98+ recon = VERSION > v " 1.9-" ? :(Splat ($ C)) : :(Base. splat ($ C))
99+ :((NamedTuple {$F} (($ (args... ),)), $ recon))
100+ else
101+ # Getting this parameterless type takes about 2μs, every time:
102+ namedtuple (x), Base. splat (Base. typename (T). wrapper)
103+ end
104+ end
105+
106+ function namedtuple (x:: T ) where T
107+ F = fieldnames (T)
108+ NamedTuple {F} (map (sy -> getfield (x, sy), F))
109+ end
110+
111+ # @layer's code for Optimisers.trainable, and perhaps anything else,
112+ # with the pattern that keywords mean function names & what fields they pick.
113+
114+ function _macro_trainable (type, fun, fields)
115+ Meta. isexpr (fields, :tuple ) || error (" expected a tuple of field names" )
116+ symbols = Tuple (map (_noquotenode, fields. args))
117+ quoted = map (QuoteNode, symbols)
118+ gets = [:(getfield (x, $ f)) for f in quoted]
119+ quote
120+ # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
121+ Flux. trainable (x:: $type ) = NamedTuple {$symbols} (($ (gets... ),)) # ?? scope is weird
122+ end
123+ end
124+ _macro_trainable (type, fun, field:: Union{Symbol,QuoteNode} ) = _macro_trainable (type, fun, :(($ field,))) # lets you forget a comma
125+
126+ _noquotenode (s:: Symbol ) = s
127+ _noquotenode (q:: QuoteNode ) = q. value # lets you write trainable=(:x,:y) instead of (x,y)
128+ _noquotenode (ex) = error (" expected a symbol, got $ex " )
129+
130+
131+
132+
133+
134+
135+ # @big_show Chain
136+ # @big_show Parallel
137+ # @big_show SkipConnection
138+ # @big_show Recur
139+ # @big_show Maxout
140+
141+
142+
143+
144+ """
145+ @big_show MyContainer
146+
147+ This macro lets you opt-in to Flux's fancy printing.
148+
149+ When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
150+ and the printing routine will recursively unfold its children.
151+ This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
152+
153+ Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
154+ need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
155+
156+ # Example
157+ ```jldoctest
158+ julia> struct Trio{A,B,C}; a::A; b::B; c::C end
159+
160+ julia> Flux.@functor Trio
161+
162+ julia> Flux.@big_show Trio
163+
164+ julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
165+ Trio(
166+ Dense(10 => 5, tanh), # 55 parameters
167+ Dense(5 => 2), # 12 parameters
168+ NNlib.softmax,
169+ ) # Total: 4 arrays, 67 parameters, 492 bytes.
170+ ```
171+
172+ Note that there is no automatic method for 2-arg `show`, and thus
173+ something like `(tri, tri)` will print all the type parameters.
174+
175+ However, `Chain(tri, tri)` will always use Flux's recursive printing,
176+ even without using this macro: `Chain` is the entry point.
177+ """
0 commit comments