11
2- using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
2+ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
33const NoT = NoTangent()
44
55"""
@@ -11,11 +11,11 @@ Differentiable.
1111
1212# Example
1313```jldoctest
14- julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 ])))
15- ([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
14+ julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im ])))
15+ (ComplexF64 [1.0 + 0.0im , 2.0 + 0.0im , 3.0 + 4.0im ], Restructure(NamedTuple, ..., 3))
1616
17- julia> re([10,20,30 ])
18- (x = [10 .0, 20 .0], y = (sin, [30.0 ]))
17+ julia> re([3, 5-im, 7+11im ])
18+ (x = [3 .0, 5 .0], y = (sin, ComplexF64[7.0 + 11.0im ]))
1919```
2020"""
2121function destructure(x)
2727 Restructure(Model, ..., length)
2828
2929This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
30- new parameters from vector `p`. If the model is callable, then `re(x, p)` .
30+ new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)` .
3131
3232# Example
3333```julia
@@ -107,22 +107,22 @@ end
107107
108108function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
109109 dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
110- _rebuild_back(dx) = (NoT, NoT, NoT, _accumulate !(x, dx , off, dflat))
110+ _rebuild_back(dx) = (NoT, NoT, NoT, _grad !(x, unthunk(dx) , off, dflat))
111111 _rebuild(x, off, flat; len), _rebuild_back
112112end
113113
114114# This is the gradient of model reconstruction, accumulating duplicates:
115- function _accumulate !(x, dx, off, flat::AbstractVector)
115+ function _grad !(x, dx, off, flat::AbstractVector)
116116 x′, _ = functor(typeof(x), x)
117117 dx′, _ = functor(typeof(x), dx)
118118 off′, _ = functor(typeof(x), off)
119- foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate !(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
119+ foreach((xᵢ, dxᵢ, oᵢ) -> _grad !(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
120120 flat
121121end
122- function _accumulate !(x, dx, off::Integer, flat::AbstractVector)
122+ function _grad !(x, dx, off::Integer, flat::AbstractVector)
123123 @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes
124124 flat
125125end
126- _accumulate !(x, dx::Zero, off, flat::AbstractVector) = nothing
127- _accumulate !(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
126+ _grad !(x, dx::Zero, off, flat::AbstractVector) = nothing
127+ _grad !(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
128128
0 commit comments