@@ -164,64 +164,67 @@ end
164164# Work around for https://github.com/JuliaLang/julia/issues/27988
165165# The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
166166# with some modification to make it also works on 1.6.
167- # TODO : make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
168167module StableFlatten
169168
170169export broadcast_flatten
171170
172- using Base: tail
173- using Base. Broadcast: isflat, Broadcasted
174-
175- maybeconstructor (f) = f
176- maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
171+ if VERSION >= v " 1.11.0-DEV.103"
172+ const broadcast_flatten = Broadcast. flatten
173+ else
174+ using Base: tail
175+ using Base. Broadcast: isflat, Broadcasted
176+
177+ maybeconstructor (f) = f
178+ maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
179+
180+ function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
181+ isflat (bc) && return bc
182+ args = cat_nested (bc)
183+ len = Val {length(args)} ()
184+ makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
185+ f = maybeconstructor (bc. f)
186+ @inline newf (args... ) = f (prepare_args (makeargs, args)... )
187+ return Broadcasted {Style} (newf, args, bc. axes)
188+ end
177189
178- function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
179- isflat (bc) && return bc
180- args = cat_nested (bc)
181- len = Val {length(args)} ()
182- makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
183- f = maybeconstructor (bc. f)
184- @inline newf (args... ) = f (prepare_args (makeargs, args)... )
185- return Broadcasted {Style} (newf, args, bc. axes)
186- end
190+ cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
191+ cat_nested_args (:: Tuple{} ) = ()
192+ cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
193+ cat_nested (@nospecialize (a)) = (a,)
187194
188- cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
189- cat_nested_args (:: Tuple{} ) = ()
190- cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
191- cat_nested (@nospecialize (a)) = (a,)
195+ function make_makeargs (args:: Tuple , len, flags)
196+ makeargs, r = _make_makeargs (args, len, flags)
197+ r isa Tuple{} || error (" Internal error. Please file a bug" )
198+ return makeargs
199+ end
192200
193- function make_makeargs (args:: Tuple , len, flags)
194- makeargs, r = _make_makeargs (args, len, flags)
195- r isa Tuple{} || error (" Internal error. Please file a bug" )
196- return makeargs
197- end
201+ # We build `makeargs` by traversing the broadcast nodes recursively.
202+ # note: `len` isa `Val` indicates the length of whole flattened argument list.
203+ # `flags` is a tuple of `Bool` with the same length of the rest arguments.
204+ @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
205+ head, flags′ = _make_makeargs1 (args[1 ], len, flags)
206+ rest, flags″ = _make_makeargs (tail (args), len, flags′)
207+ (head, rest... ), flags″
208+ end
209+ _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
198210
199- # We build `makeargs` by traversing the broadcast nodes recursively.
200- # note: `len` isa `Val` indicates the length of whole flattened argument list.
201- # `flags` is a tuple of `Bool` with the same length of the rest arguments.
202- @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
203- head, flags′ = _make_makeargs1 (args[1 ], len, flags)
204- rest, flags″ = _make_makeargs (tail (args), len, flags′)
205- (head, rest... ), flags″
206- end
207- _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
211+ # For flat nodes:
212+ # 1. we just consume one argument, and return the "pick" function
213+ @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
214+ pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
215+ return pickargs (Val {N-length(flags)+1} ()), tail (flags)
216+ end
208217
209- # For flat nodes:
210- # 1. we just consume one argument, and return the "pick" function
211- @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
212- pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
213- return pickargs (Val {N-length(flags)+1} ()), tail (flags)
214- end
218+ # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
219+ @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
220+ makeargs, flags′ = _make_makeargs (bc. args, len, flags)
221+ f = maybeconstructor (bc. f)
222+ @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
223+ makeargs1, flags′
224+ end
215225
216- # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
217- @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
218- makeargs, flags′ = _make_makeargs (bc. args, len, flags)
219- f = maybeconstructor (bc. f)
220- @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
221- makeargs1, flags′
226+ prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
227+ @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
222228end
223-
224- prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
225- @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
226229end
227230using . StableFlatten
0 commit comments