7474
7575
7676"""
77- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7878
7979Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
8080"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
8484Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} = ntuple (d -> Val {d == N} (), Val {D} ())
8585
8686"""
87- If the contiguous dimension is not the dimension with `Stride_rank {1}`:
87+ If the contiguous dimension is not the dimension with `StrideRank {1}`:
8888"""
8989struct ContiguousBatch{N} end
9090Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
9191_get (:: ContiguousBatch{N} ) where {N} = N
9292
9393"""
94- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
9595
9696Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
9797If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126126@inline Base. getindex (:: StrideRank{R} , :: Val{I} ) where {R,I} = StrideRank {permute(R, I)} ()
127127
128128"""
129- rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129+ rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130130
131131Returns the `sortperm` of the stride ranks.
132132"""
@@ -177,7 +177,9 @@ stride_rank(x, i) = stride_rank(x)[i]
177177stride_rank (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = StrideRank {ntuple(identity, Val{N}())} ()
178178
179179"""
180- is_column_major(A) -> Val{true/false}()
180+ is_column_major(A) -> Val{true/false}()
181+
182+ Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
181183"""
182184is_column_major (A) = is_column_major (stride_rank (A), contiguous_batch_size (A))
183185is_column_major (:: Nothing , :: Any ) = Val {false} ()
@@ -197,7 +199,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197199@inline Base. getindex (:: DenseDims{D} , i:: Integer ) where {D} = D[i]
198200@inline Base. getindex (:: DenseDims{D} , :: Val{I} ) where {D,I} = DenseDims {permute(D, I)} ()
199201"""
200- dense_dims(::Type{T}) -> NTuple{N,Bool}
202+ dense_dims(::Type{T}) -> NTuple{N,Bool}
201203
202204Returns a tuple of indicators for whether each axis is dense.
203205An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -250,7 +252,7 @@ permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}(
250252end
251253
252254"""
253- strides(A)
255+ strides(A) -> Tuple
254256
255257Returns the strides of array `A`. If any strides are known at compile time,
256258these should be returned as `Static` numbers. For example:
@@ -274,8 +276,196 @@ while still producing correct behavior when using valid cartesian indices, such
274276strides (A) = Base. strides (A)
275277strides (A, d) = strides (A)[to_dims (A, d)]
276278
279+ @generated function _perm_tuple (:: Type{T} , :: Val{P} ) where {T,P}
280+ out = Expr (:curly , :Tuple )
281+ for p in P
282+ push! (out. args, :(T. parameters[$ p]))
283+ end
284+ Expr (:block , Expr (:meta , :inline ), out)
285+ end
286+
287+ """
288+ axes_types(::Type{T}[, d]) -> Type
289+
290+ Returns the type of the axes for `T`
291+ """
292+ axes_types (x) = axes_types (typeof (x))
293+ axes_types (x, d) = axes_types (typeof (x), d)
294+ @inline axes_types (:: Type{T} , d) where {T} = axes_types (T). parameters[to_dims (T, d)]
295+ function axes_types (:: Type{T} ) where {T}
296+ if parent_type (T) <: T
297+ return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims (T)}}
298+ else
299+ return axes_types (parent_type (T))
300+ end
301+ end
302+ axes_types (:: Type{T} ) where {T<: Adjoint } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
303+ axes_types (:: Type{T} ) where {T<: Transpose } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
304+ function axes_types (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
305+ return _perm_tuple (axes_types (parent_type (T)), Val (I1))
306+ end
307+ function axes_types (:: Type{T} ) where {T<: AbstractRange }
308+ if known_length (T) === nothing
309+ return Tuple{OptionallyStaticUnitRange{One,Int}}
310+ else
311+ return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length (T)}}}
312+ end
313+ end
314+
315+ @inline function axes_types (:: Type{T} ) where {P,I,T<: SubArray{<:Any,<:Any,P,I} }
316+ return _sub_axes_types (Val (ArrayStyle (T)), I, axes_types (P))
317+ end
318+ @generated function _sub_axes_types (:: Val{S} , :: Type{I} , :: Type{PI} ) where {S,I<: Tuple ,PI<: Tuple }
319+ out = Expr (:curly , :Tuple )
320+ d = 1
321+ for i in I. parameters
322+ ad = argdims (S, i)
323+ if ad > 0
324+ push! (out. args, :(sub_axis_type ($ (PI. parameters[d]), $ i)))
325+ d += ad
326+ else
327+ d += 1
328+ end
329+ end
330+ Expr (:block , Expr (:meta , :inline ), out)
331+ end
332+
333+ @inline function axes_types (:: Type{T} ) where {T<: Base.ReinterpretArray }
334+ return _reinterpret_axes_types (axes_types (parent_type (T)), eltype (T), eltype (parent_type (T)))
335+ end
336+ @generated function _reinterpret_axes_types (:: Type{I} , :: Type{T} , :: Type{S} ) where {I<: Tuple ,T,S}
337+ out = Expr (:curly , :Tuple )
338+ for i in 1 : length (I. parameters)
339+ if i === 1
340+ push! (out. args, reinterpret_axis_type (I. parameters[1 ], T, S))
341+ else
342+ push! (out. args, I. parameters[i])
343+ end
344+ end
345+ Expr (:block , Expr (:meta , :inline ), out)
346+ end
347+
348+
349+ # These methods help handle identifying axes that dont' directly propagate from the
350+ # parent array axes. They may be worth making a formal part of the API, as they provide
351+ # a low traffic spot to change what axes_types produces.
352+ @inline function sub_axis_type (:: Type{A} , :: Type{I} ) where {A,I}
353+ if known_length (I) === nothing
354+ return OptionallyStaticUnitRange{One,Int}
355+ else
356+ return OptionallyStaticUnitRange{One,StaticInt{known_length (I)}}
357+ end
358+ end
359+
360+ @inline function reinterpret_axis_type (:: Type{A} , :: Type{T} , :: Type{S} ) where {A,T,S}
361+ if known_length (A) === nothing
362+ return OptionallyStaticUnitRange{One,Int}
363+ else
364+ return OptionallyStaticUnitRange{One,StaticInt{Int (known_length (A) / (sizeof (T) / sizeof (S)))}}
365+ end
366+ end
367+
277368"""
278- offsets(A)
369+ known_offsets(::Type{T}[, d]) -> Tuple
370+
371+ Returns a tuple of offset values known at compile time. If the offset of a given axis is
372+ not known at compile time `nothing` is returned its position.
373+ """
374+ @inline known_offsets (x, d) = known_offsets (x)[to_dims (x, d)]
375+ known_offsets (x) = known_offsets (typeof (x))
376+ @generated function known_offsets (:: Type{T} ) where {T}
377+ out = Expr (:tuple )
378+ for p in axes_types (T). parameters
379+ push! (out. args, known_first (p))
380+ end
381+ return out
382+ end
383+
384+ """
385+ known_size(::Type{T}[, d]) -> Tuple
386+
387+ Returns the size of each dimension for `T` known at compile time. If a dimension does not
388+ have a known size along a dimension then `nothing` is returned in its position.
389+ """
390+ @inline known_size (x, d) = known_size (x)[to_dims (x, d)]
391+ known_size (x) = known_size (typeof (x))
392+ known_size (:: Type{T} ) where {T} = _known_size (axes_types (T))
393+ @generated function _known_size (:: Type{Axs} ) where {Axs<: Tuple }
394+ out = Expr (:tuple )
395+ for p in Axs. parameters
396+ push! (out. args, :(known_length ($ p)))
397+ end
398+ return Expr (:block , Expr (:meta , :inline ), out)
399+ end
400+
401+ """
402+ known_strides(::Type{T}[, d]) -> Tuple
403+
404+ Returns the strides of array `A` known at compile time. Any strides that are not known at
405+ compile time are represented by `nothing`.
406+ """
407+ known_strides (x) = known_strides (typeof (x))
408+ known_strides (x, d) = known_strides (x)[to_dims (x, d)]
409+ known_strides (:: Type{T} ) where {T<: Vector } = (1 ,)
410+ @inline function known_strides (:: Type{T} ) where {T<: Adjoint{<:Any,<:AbstractVector} }
411+ strd = first (known_strides (parent_type (T)))
412+ return (strd, strd)
413+ end
414+ function known_strides (:: Type{T} ) where {T<: Adjoint }
415+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
416+ end
417+ function known_strides (:: Type{T} ) where {T<: Transpose }
418+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
419+ end
420+ @inline function known_strides (:: Type{T} ) where {T<: Transpose{<:Any,<:AbstractVector} }
421+ strd = first (known_strides (parent_type (T)))
422+ return (strd, strd)
423+ end
424+ @inline function known_strides (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
425+ return permute (known_strides (parent_type (T)), Val {I1} ())
426+ end
427+ @inline function known_strides (:: Type{T} ) where {I1,T<: SubArray{<:Any,<:Any,<:Any,I1} }
428+ return _sub_strides (Val (ArrayStyle (T)), I1, Val (known_strides (parent_type (T))))
429+ end
430+
431+ @generated function _sub_strides (:: Val{S} , :: Type{I} , :: Val{P} ) where {S,I<: Tuple ,P}
432+ out = Expr (:tuple )
433+ d = 1
434+ for i in I. parameters
435+ ad = argdims (S, i)
436+ if ad > 0
437+ push! (out. args, P[d])
438+ d += ad
439+ else
440+ d += 1
441+ end
442+ end
443+ Expr (:block , Expr (:meta , :inline ), out)
444+ end
445+
446+ function known_strides (:: Type{T} ) where {T}
447+ if ndims (T) === 1
448+ return (1 ,)
449+ else
450+ return _known_strides (Val (Base. front (known_size (T))))
451+ end
452+ end
453+ @generated function _known_strides (:: Val{S} ) where {S}
454+ out = Expr (:tuple )
455+ N = length (S)
456+ push! (out. args, 1 )
457+ for s in S
458+ if s === nothing || out. args[end ] === nothing
459+ push! (out. args, nothing )
460+ else
461+ push! (out. args, out. args[end ] * s)
462+ end
463+ end
464+ return Expr (:block , Expr (:meta , :inline ), out)
465+ end
466+
467+ """
468+ offsets(A) -> Tuple
279469
280470Returns offsets of indices with respect to 0. If values are known at compile time,
281471it should return them as `Static` numbers.
294484 strd = stride (parent (x), One ())
295485 (strd, strd)
296486end
297-
487+
298488@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
299489 if C ≤ 0 || C > N
300490 return Expr (:block , Expr (:meta ,:inline ), :s )
@@ -325,15 +515,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325515 quote
326516 $ (Expr (:meta ,:inline ))
327517 @inbounds $ stup
328- end
518+ end
329519 end
330520end
331521
332- @inline function offsets (x, i)
333- inds = indices (x, i)
334- start = known_first (inds)
335- isnothing (start) ? first (inds) : StaticInt (start)
336- end
522+ @inline offsets (x, i) = static_first (indices (x, i))
337523# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338524# Explicit tuple needed for inference.
339525@generated function offsets (A:: AbstractArray{<:Any,N} ) where {N}
344530 Expr (:block , Expr (:meta , :inline ), t)
345531end
346532
533+ @inline size (v:: AbstractVector ) = (static_length (axes_types (v, 1 )),)
347534@inline size (B:: Union{Transpose{T,A},Adjoint{T,A}} ) where {T,A<: AbstractMatrix{T} } = permute (size (parent (B)), Val {(2,1)} ())
348535@inline size (B:: PermutedDimsArray{T,N,I1,I2,A} ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = permute (size (parent (B)), Val {I1} ())
349536@inline size (A:: AbstractArray , :: StaticInt{N} ) where {N} = size (A)[N]
0 commit comments