@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
6060@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
6161 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
6262 argsizes = broadcast_sizes (as... )
63- destsize = combine_sizes (argsizes)
64- _broadcast (f, destsize, argsizes, as... )
63+ ax = axes (B)
64+ if ax isa Tuple{Vararg{SOneTo}}
65+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
66+ end
67+ return copy (convert (Broadcasted{DefaultArrayStyle{M}}, B))
6568end
6669# copyto! overloads
6770@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
6871@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
6972@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
7073 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
7174 argsizes = broadcast_sizes (as... )
72- destsize = combine_sizes (( Size (dest), argsizes ... ) )
73- if Length (destsize) === Length {Dynamic()} ()
74- # destination dimension cannot be determined statically; fall back to generic broadcast!
75- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
75+ ax = axes (B )
76+ if ax isa Tuple{Vararg{SOneTo}}
77+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
78+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
7679 end
77- _broadcast! (f, destsize, dest, argsizes, as... )
80+ # destination dimension cannot be determined statically; fall back to generic broadcast!
81+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
7882end
7983
8084# Resolving priority between dynamic and static axes
@@ -99,45 +103,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
99103@inline broadcast_size (a:: AbstractArray ) = Size (a)
100104@inline broadcast_size (a:: Tuple ) = Size (length (a))
101105
102- function broadcasted_index (oldsize, newindex)
103- index = ones (Int, length (oldsize))
104- for i = 1 : length (oldsize)
105- if oldsize[i] != 1
106- index[i] = newindex[i]
107- end
108- end
109- return LinearIndices (oldsize)[index... ]
110- end
111-
112- # similar to Base.Broadcast.combine_indices:
113- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
114- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
115- ndims = 0
116- for i = 1 : length (sizes)
117- ndims = max (ndims, length (sizes[i]))
118- end
119- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
120- for i = 1 : length (sizes)
121- s = sizes[i]
122- for j = 1 : length (s)
123- if s[j] isa Dynamic
124- continue
125- elseif newsize[j] isa Dynamic || newsize[j] == 1
126- newsize[j] = s[j]
127- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
128- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
129- end
130- end
131- end
132- quote
133- @_inline_meta
134- Size ($ (tuple (newsize... )))
135- end
106+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
107+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
108+ li = LinearIndices (oldsize)
109+ ind = _broadcast_getindex (li, newindex)
110+ return :(a[$ i][$ ind])
136111end
137112
138- scalar_getindex (x) = x
139- scalar_getindex (x:: Ref ) = x[]
140-
141113isstatic (:: StaticArray ) = true
142114isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
143115isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -161,13 +133,11 @@ end
161133
162134@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
163135 sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
136+
164137 indices = CartesianIndices (newsize)
165138 exprs = similar (indices, Expr)
166139 for (j, current_ind) ∈ enumerate (indices)
167- exprs_vals = [
168- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
169- for i = 1 : length (sizes)
170- ]
140+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
171141 exprs[j] = :(f ($ (exprs_vals... )))
172142 end
173143
@@ -181,27 +151,18 @@ end
181151# # Internal broadcast! machinery for StaticArrays ##
182152# ###################################################
183153
184- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
185- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
186- sizes = tuple (sizes... )
187-
188- # TODO : this could also be done outside the generated function:
189- sizematch (Size {newsize} (), Size (dest)) ||
190- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
154+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
155+ sizes = [sz. parameters[1 ] for sz in s. parameters]
191156
192157 indices = CartesianIndices (newsize)
193158 exprs = similar (indices, Expr)
194159 for (j, current_ind) ∈ enumerate (indices)
195- exprs_vals = [
196- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
197- for i = 1 : length (sizes)
198- ]
160+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
199161 exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
200162 end
201163
202164 return quote
203- @_propagate_inbounds_meta
204- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
165+ @_inline_meta
205166 @inbounds $ (Expr (:block , exprs... ))
206167 return dest
207168 end
0 commit comments