@@ -37,11 +37,9 @@ function axes_types(::Type{T}) where {T}
3737end
3838axes_types (:: Type{LinearIndices{N,R}} ) where {N,R} = R
3939axes_types (:: Type{CartesianIndices{N,R}} ) where {N,R} = R
40- function axes_types (:: Type{T} ) where {T<: VecAdjTrans }
41- return Tuple{OptionallyStaticUnitRange{One,One},axes_types (parent_type (T), One ())}
42- end
43- function axes_types (:: Type{T} ) where {T<: MatAdjTrans }
44- return eachop_tuple (_get_tuple, to_parent_dims (T), axes_types (parent_type (T)))
40+ function axes_types (:: Type{T} ) where {T<: Union{Adjoint,Transpose} }
41+ P = parent_type (T)
42+ return Tuple{axes_types (P, static (2 )), axes_types (P, static (1 ))}
4543end
4644function axes_types (:: Type{T} ) where {T<: PermutedDimsArray }
4745 return eachop_tuple (_get_tuple, to_parent_dims (T), axes_types (parent_type (T)))
@@ -133,6 +131,21 @@ function axes(a::A, dim::Integer) where {A}
133131 return axes (parent (a), to_parent_dims (A, dim))
134132 end
135133end
134+ function axes (A:: CartesianIndices{N} , dim:: Integer ) where {N}
135+ if dim > N
136+ return static (1 ): static (1 )
137+ else
138+ return getfield (axes (A), Int (dim))
139+ end
140+ end
141+ function axes (A:: LinearIndices{N} , dim:: Integer ) where {N}
142+ if dim > N
143+ return static (1 ): static (1 )
144+ else
145+ return getfield (axes (A), Int (dim))
146+ end
147+ end
148+
136149axes (A:: SubArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
137150axes (A:: ReinterpretArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
138151axes (A:: Base.ReshapedArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
@@ -160,3 +173,137 @@ axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface vers
160173axes (A:: CartesianIndices ) = A. indices
161174axes (A:: LinearIndices ) = A. indices
162175
176+ """
177+ LazyAxis{N}(parent::AbstractArray)
178+
179+ A lazy representation of `axes(parent, N)`.
180+ """
181+ struct LazyAxis{N,P} <: AbstractUnitRange{Int}
182+ parent:: P
183+
184+ LazyAxis {N} (parent:: P ) where {N,P} = new {N::Int,P} (parent)
185+ @inline function LazyAxis {:} (parent:: P ) where {P}
186+ if ndims (P) === 1
187+ return new {1,P} (parent)
188+ else
189+ return new {:,P} (parent)
190+ end
191+ end
192+ end
193+
194+ @inline Base. parent (x:: LazyAxis{N,P} ) where {N,P} = axes (getfield (x, :parent ), static (N))
195+ @inline function Base. parent (x:: LazyAxis{:,P} ) where {P}
196+ return eachindex (IndexLinear (), getfield (x, :parent ))
197+ end
198+
199+ @inline parent_type (:: Type{LazyAxis{N,P}} ) where {N,P} = axes_types (P, static (N))
200+ # TODO this approach to parent_type(::Type{LazyAxis{:}}) is a bit hacky. Something like
201+ # LabelledArrays has a linear set of symbolic keys, which could be propagated through
202+ # `to_indices` for key based indexing. However, there currently isn't a good way of handling
203+ # that when the linear indices aren't linearly accessible through a child array (e.g, adjoint)
204+ # For now we just make sure the linear elements are accurate.
205+ parent_type (:: Type{LazyAxis{:,P}} ) where {P<: Array } = OneTo{Int}
206+ @inline function parent_type (:: Type{LazyAxis{:,P}} ) where {P}
207+ if known_length (P) === nothing
208+ return OptionallyStaticUnitRange{StaticInt{1 },Int}
209+ else
210+ return OptionallyStaticUnitRange{StaticInt{1 },StaticInt{known_length (P)}}
211+ end
212+ end
213+
214+ Base. keys (x:: LazyAxis ) = keys (parent (x))
215+
216+ Base. IndexStyle (:: Type{T} ) where {T<: LazyAxis } = IndexStyle (parent_type (T))
217+
218+ can_change_size (:: Type{LazyAxis{N,P}} ) where {N,P} = can_change_size (P)
219+
220+ known_first (:: Type{T} ) where {T<: LazyAxis } = known_first (parent_type (T))
221+
222+ known_length (:: Type{LazyAxis{N,P}} ) where {N,P} = known_size (P, N)
223+ known_length (:: Type{LazyAxis{:,P}} ) where {P} = known_length (P)
224+
225+ @inline function known_last (:: Type{T} ) where {T<: LazyAxis }
226+ return _lazy_axis_known_last (known_first (T), known_length (T))
227+ end
228+ _lazy_axis_known_last (start:: Int , length:: Int ) = (length + start) - 1
229+ _lazy_axis_known_last (:: Any , :: Any ) = nothing
230+
231+ @inline function Base. first (x:: LazyAxis{N} ):: Int where {N}
232+ if known_first (x) === nothing
233+ return offsets (getfield (x, :parent ), static (N))
234+ else
235+ return known_first (x)
236+ end
237+ end
238+ @inline function Base. first (x:: LazyAxis{:} ):: Int
239+ if known_first (x) === nothing
240+ return firstindex (getfield (x, :parent ))
241+ else
242+ return known_first (x)
243+ end
244+ end
245+
246+ @inline function Base. length (x:: LazyAxis{N} ):: Int where {N}
247+ if known_length (x) === nothing
248+ return size (getfield (x, :parent ), static (N))
249+ else
250+ return known_length (x)
251+ end
252+ end
253+ @inline function Base. length (x:: LazyAxis{:} ):: Int
254+ if known_length (x) === nothing
255+ return lastindex (getfield (x, :parent ))
256+ else
257+ return known_length (x)
258+ end
259+ end
260+
261+ @inline function Base. last (x:: LazyAxis ):: Int
262+ if known_last (x) === nothing
263+ if known_first (x) === 1
264+ return length (x)
265+ else
266+ return (static_length (x) + static_first (x)) - 1
267+ end
268+ else
269+ return known_last (x)
270+ end
271+ end
272+
273+ Base. to_shape (x:: LazyAxis ) = length (x)
274+
275+ @inline function Base. checkindex (:: Type{Bool} , x:: LazyAxis , i:: Integer )
276+ if known_first (x) === nothing || known_last (x) === nothing
277+ return checkindex (Bool, parent (x), i)
278+ else # everything is static so we don't have to retrieve the axis
279+ return (! (known_first (x) > i) || ! (known_last (x) < i))
280+ end
281+ end
282+
283+ @propagate_inbounds function Base. getindex (x:: LazyAxis , i:: Integer )
284+ @boundscheck checkindex (Bool, x, i) || throw (BoundsError (x, i))
285+ return Int (i)
286+ end
287+ @propagate_inbounds Base. getindex (x:: LazyAxis , i:: StepRange{T} ) where {T<: Integer } = parent (x)[i]
288+ @propagate_inbounds Base. getindex (x:: LazyAxis , i:: AbstractUnitRange{<:Integer} ) = parent (x)[i]
289+
290+ Base. show (io:: IO , x:: LazyAxis{N} ) where {N} = print (io, " LazyAxis{$N }($(parent (x)) ))" )
291+
292+ """
293+ lazy_axes(x)
294+
295+ Produces a tuple of axes where each axis is constructed lazily. If an axis of `x` is already
296+ constructed or it is simply retrieved.
297+ """
298+ @generated function lazy_axes (x:: X ) where {X}
299+ Expr (:block ,
300+ Expr (:meta , :inline ),
301+ Expr (:tuple , [:(LazyAxis {$dim} (x)) for dim in 1 : ndims (X)]. .. )
302+ )
303+ end
304+ lazy_axes (x:: LinearIndices ) = axes (x)
305+ lazy_axes (x:: CartesianIndices ) = axes (x)
306+ @inline lazy_axes (x:: MatAdjTrans ) = reverse (lazy_axes (parent (x)))
307+ @inline lazy_axes (x:: VecAdjTrans ) = (LazyAxis {1} (x), first (lazy_axes (parent (x))))
308+ @inline lazy_axes (x:: PermutedDimsArray ) = permute (lazy_axes (parent (x)), to_parent_dims (A))
309+
0 commit comments