@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494494end
495495
496496# broadcast
497- import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497+ import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498+ using Base. Broadcast: combine_styles
498499
499500struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500501
@@ -524,6 +525,82 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524525
525526BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
526527
528+ """
529+ always_struct_broadcast(style::BroadcastStyle)
530+
531+ Check if `style` supports struct-broadcast natively, which means:
532+ 1) `Base.copy` is not overloaded.
533+ 2) `Base.similar` is defined.
534+ 3) `Base.copyto!` supports `StructArray`s as broadcasted arguments.
535+
536+ If any of the above conditions are not met, then this function should
537+ not be overloaded.
538+ In that case, try to overload [`try_struct_copy`](@ref) to support out-of-place
539+ struct-broadcast.
540+ """
541+ always_struct_broadcast (:: Any ) = false
542+ always_struct_broadcast (:: DefaultArrayStyle ) = true
543+ always_struct_broadcast (:: ArrayConflict ) = true
544+
545+ """
546+ try_struct_copy(bc::Broadcasted)
547+
548+ Entry for non-native outplace struct-broadcast.
549+
550+ See also [`always_struct_broadcast`](@ref).
551+ """
552+ try_struct_copy (bc:: Broadcasted ) = copy (bc)
553+
554+ function Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
555+ if always_struct_broadcast (S ())
556+ return invoke (copy, Tuple{Broadcasted}, bc)
557+ else
558+ return try_struct_copy (replace_structarray (bc))
559+ end
560+ end
561+
562+ """
563+ replace_structarray(bc::Broadcasted)
564+
565+ An internal function transforms the `Broadcasted` with `StructArray` into
566+ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567+ supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568+ e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569+ """
570+ function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
571+ args = replace_structarray_args (bc. args)
572+ Style′ = parent_style (Style ())
573+ return Broadcasted {Style′} (bc. f, args, bc. axes)
574+ end
575+ function replace_structarray (A:: StructArray )
576+ f = Instantiator (eltype (A))
577+ args = Tuple (components (A))
578+ Style = typeof (combine_styles (args... ))
579+ return Broadcasted {Style} (f, args, axes (A))
580+ end
581+ replace_structarray (@nospecialize (A)) = A
582+
583+ replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584+ replace_structarray_args (:: Tuple{} ) = ()
585+
586+ parent_style (@nospecialize (x)) = typeof (x)
587+ parent_style (:: StructArrayStyle{S, N} ) where {S, N} = S
588+ parent_style (:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
589+ parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
590+ parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof (S (Val (N)))
591+
592+ # `instantiate` and `_axes` might be overloaded for static axes.
593+ function Broadcast. instantiate (bc:: Broadcasted{Style} ) where {Style <: StructArrayStyle }
594+ Style′ = parent_style (Style ())
595+ bc′ = Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596+ return convert (Broadcasted{Style}, bc′)
597+ end
598+
599+ function Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where {Style <: StructArrayStyle }
600+ Style′ = parent_style (Style ())
601+ return Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
602+ end
603+
527604# Here we use `similar` defined for `S` to build the dest Array.
528605function Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
529606 bc′ = convert (Broadcasted{S}, bc)
@@ -532,12 +609,22 @@ end
532609
533610# Unwrapper to recover the behaviour defined by parent style.
534611@inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
535- return copyto! (dest, convert (Broadcasted{S}, bc))
612+ bc′ = always_struct_broadcast (S ()) ? convert (Broadcasted{S}, bc) : replace_structarray (bc)
613+ return copyto! (dest, bc′)
536614end
537615
538616@inline function Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
539- return Broadcast. materialize! (S (), dest, bc)
617+ bc′ = always_struct_broadcast (S ()) ? bc : replace_structarray (bc)
618+ return Broadcast. materialize! (S (), dest, bc′)
540619end
541620
542621# for aliasing analysis during broadcast
622+ function Broadcast. broadcast_unalias (dest:: StructArray , src:: AbstractArray )
623+ if dest === src || any (Base. Fix2 (=== , src), components (dest))
624+ return src
625+ else
626+ return Base. unalias (dest, src)
627+ end
628+ end
629+
543630Base. dataids (u:: StructArray ) = mapreduce (Base. dataids, (a, b) -> (a... , b... ), values (components (u)), init= ())
0 commit comments