7474
7575
7676"""
77- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7878
7979Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
8080"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
8484Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} = ntuple (d -> Val {d == N} (), Val {D} ())
8585
8686"""
87- If the contiguous dimension is not the dimension with `Stride_rank {1}`:
87+ If the contiguous dimension is not the dimension with `StrideRank {1}`:
8888"""
8989struct ContiguousBatch{N} end
9090Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
9191_get (:: ContiguousBatch{N} ) where {N} = N
9292
9393"""
94- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
9595
9696Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
9797If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126126@inline Base. getindex (:: StrideRank{R} , :: Val{I} ) where {R,I} = StrideRank {permute(R, I)} ()
127127
128128"""
129- rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129+ rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130130
131131Returns the `sortperm` of the stride ranks.
132132"""
@@ -197,7 +197,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197197@inline Base. getindex (:: DenseDims{D} , i:: Integer ) where {D} = D[i]
198198@inline Base. getindex (:: DenseDims{D} , :: Val{I} ) where {D,I} = DenseDims {permute(D, I)} ()
199199"""
200- dense_dims(::Type{T}) -> NTuple{N,Bool}
200+ dense_dims(::Type{T}) -> NTuple{N,Bool}
201201
202202Returns a tuple of indicators for whether each axis is dense.
203203An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -274,8 +274,217 @@ while still producing correct behavior when using valid cartesian indices, such
274274strides (A) = Base. strides (A)
275275strides (A, d) = strides (A)[to_dims (A, d)]
276276
277+ @generated function _perm_tuple (:: Type{T} , :: Val{P} ) where {T,P}
278+ out = Expr (:curly , :Tuple )
279+ for p in P
280+ push! (out. args, :(T. parameters[$ p]))
281+ end
282+ Expr (:block , Expr (:meta , :inline ), out)
283+ end
284+
285+ """
286+ axes_types(::Type{T}[, d]) -> Type
287+
288+ Returns the type of the axes for `T`
289+ """
290+ axes_types (x) = axes_types (typeof (x))
291+ axes_types (x, d) = axes_types (typeof (x), d)
292+ @inline axes_types (:: Type{T} , d) where {T} = axes_types (T). parameters[to_dims (T, d)]
293+ function axes_types (:: Type{T} ) where {T}
294+ if parent_type (T) <: T
295+ return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims (T)}}
296+ else
297+ return axes_types (parent_type (T))
298+ end
299+ end
300+ axes_types (:: Type{T} ) where {T<: Adjoint } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
301+ axes_types (:: Type{T} ) where {T<: Transpose } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
302+ function axes_types (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
303+ return _perm_tuple (axes_types (parent_type (T)), Val (I1))
304+ end
305+ function axes_types (:: Type{T} ) where {T<: OptionallyStaticRange }
306+ if known_length (T) === nothing
307+ return Tuple{OptionallyStaticUnitRange{One,Int}}
308+ else
309+ return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length (T) - 1 }}}
310+ end
311+ end
312+
313+ @inline function axes_types (:: Type{T} ) where {P,I,T<: SubArray{<:Any,<:Any,P,I} }
314+ return _sub_axes_types (Val (ArrayStyle (T)), I, axes_types (P))
315+ end
316+ @generated function _sub_axes_types (:: Val{S} , :: Type{I} , :: Type{PI} ) where {S,I<: Tuple ,PI<: Tuple }
317+ out = Expr (:curly , :Tuple )
318+ d = 1
319+ for i in I. parameters
320+ ad = argdims (S, i)
321+ if ad > 0
322+ push! (out. args, :(sub_axis_type ($ (PI. parameters[d]), $ i)))
323+ d += ad
324+ else
325+ d += 1
326+ end
327+ end
328+ Expr (:block , Expr (:meta , :inline ), out)
329+ end
330+
331+ @inline function axes_types (:: Type{T} ) where {T<: Base.ReinterpretArray }
332+ return _reinterpret_axes_types (axes_type (parent_type (T)), eltype (T), eltype (parent_type (T)))
333+ end
334+ @generated function _reinterpret_axes_types (:: Type{I} , :: Type{T} , :: Type{S} ) where {I<: Tuple ,T,S}
335+ out = Expr (:curly , :Tuple )
336+ for i in 1 : length (T. parameters)
337+ if i === 1
338+ push! (out. args, :(reinterpret_axis_type ($ (I. parameters[1 ]), $ T, $ S)))
339+ else
340+ # FIXME double check this once I've slept
341+ push! (out. args, :($ (I. parameters[i])))
342+ end
343+ end
344+ Expr (:block , Expr (:meta , :inline ), out)
345+ end
346+
347+
348+ # These methods help handle identifying axes that dont' directly propagate from the
349+ # parent array axes. They may be worth making a formal part of the API, as they provide
350+ # a low traffic spot to change what axes_types produces.
351+ @inline function sub_axis_type (:: Type{A} , :: Type{I} ) where {A,I}
352+ if known_length (I) === nothing
353+ return OptionallyStaticUnitRange{One,Int}
354+ else
355+ return OptionallyStaticUnitRange{One,StaticInt{known_length (I)}}
356+ end
357+ end
358+
359+ @inline function reinterpret_axis_type (:: Type{A} , :: Type{T} , :: Type{S} ) where {A,T,S}
360+ if known_length (A) === nothing
361+ return OptionallyStaticUnitRange{One,Int}
362+ else
363+ return OptionallyStaticUnitRange{One,StaticInt{Int (known_length (A) / (sizeof (T) / sizeof (S))) - 1 }}
364+ end
365+ end
366+
367+ """
368+ known_offsets(::Type{T}[, d]) -> Tuple
369+
370+ Returns a tuple of offset values known at compile time. If the offset of a given axis is
371+ not known at compile time `nothing` is returned its position.
372+ """
373+ @inline known_offsets (x, d) = known_offsets (x)[to_dims (x, d)]
374+ known_offsets (x) = known_offsets (typeof (x))
375+ @generated function known_offsets (:: Type{T} ) where {T}
376+ out = Expr (:tuple )
377+ for p in axes_types (T). parameters
378+ push! (out. args, known_first (p))
379+ end
380+ return out
381+ end
382+
383+ """
384+ known_size(::Type{T}[, d]) -> Tuple
385+
386+ Returns the size of each dimension for `T` known at compile time. If a dimension does not
387+ have a known size along a dimension then `nothing` is returned in its position.
388+ """
389+ @inline known_size (x, d) = known_size (x)[to_dims (x, d)]
390+ known_size (x) = known_size (typeof (x))
391+ known_size (:: Type{T} ) where {T} = _known_size (axes_types (T))
392+ @generated function _known_size (:: Type{Axs} ) where {Axs<: Tuple }
393+ out = Expr (:tuple )
394+ for p in Axs. parameters
395+ push! (out. args, :(known_length ($ p)))
396+ end
397+ return Expr (:block , Expr (:meta , :inline ), out)
398+ end
399+
277400"""
278- offsets(A)
401+ known_strides(::Type{T}[, d]) -> Tuple
402+ """
403+ known_strides (x) = known_strides (typeof (x))
404+ known_strides (x, d) = known_strides (x)[to_dims (x, d)]
405+ known_strides (:: Type{T} ) where {T<: Vector } = (1 ,)
406+ @inline function known_strides (:: Type{T} ) where {T<: Adjoint{<:Any,<:AbstractVector} }
407+ strd = first (known_strides (parent_type (T)))
408+ return (strd, strd)
409+ end
410+ function known_strides (:: Type{T} ) where {T<: Adjoint }
411+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
412+ end
413+ function known_strides (:: Type{T} ) where {T<: Transpose }
414+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
415+ end
416+ @inline function known_strides (:: Type{T} ) where {T<: Transpose{<:Any,<:AbstractVector} }
417+ strd = first (known_strides (parent_type (T)))
418+ return (strd, strd)
419+ end
420+ @inline function known_strides (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
421+ return permute (known_strides (parent_type (T)), Val {I1} ())
422+ end
423+ @inline function known_strides (:: Type{T} ) where {I1,T<: SubArray{<:Any,<:Any,<:Any,I1} }
424+ return _sub_strides (Val (ArrayStyle (T)), I1, Val (known_strides (parent_type (T))))
425+ end
426+
427+ @generated function _sub_strides (:: Val{S} , :: Type{I} , :: Val{P} ) where {S,I<: Tuple ,P}
428+ out = Expr (:tuple )
429+ d = 1
430+ for i in I. parameters
431+ ad = argdims (S, i)
432+ if ad > 0
433+ push! (out. args, P[d])
434+ d += ad
435+ else
436+ d += 1
437+ end
438+ end
439+ Expr (:block , Expr (:meta , :inline ), out)
440+ end
441+
442+ function known_strides (:: Type{T} ) where {T}
443+ if ndims (T) === 1
444+ return (1 ,)
445+ else
446+ return _known_strides (Val (Base. front (known_size (T))))
447+ end
448+ end
449+ @generated function _known_strides (:: Val{S} ) where {S}
450+ out = Expr (:tuple )
451+ N = length (S)
452+ push! (out. args, 1 )
453+ for s in S
454+ if s === nothing || out. args[end ] === nothing
455+ push! (out. args, nothing )
456+ else
457+ push! (out. args, out. args[end ] * s)
458+ end
459+ end
460+ return Expr (:block , Expr (:meta , :inline ), out)
461+ end
462+
463+ #=
464+
465+ function strides(a::ReinterpretArray)
466+ a.parent isa StridedArray || ArgumentError("Parent must be strided.") |> throw
467+ size_to_strides(1, size(a)...)
468+ end
469+
470+ strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
471+ @generated function _strides(_::Base.ReinterpretArray{T, N, S, A, true}, s::NTuple{N}, ::Contiguous{1}) where {T, N, S, D, A <: Array{S,D}}
472+ stup = Expr(:tuple, :(One()))
473+ if D < N
474+ push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T))))
475+ end
476+ for n ∈ 2+(D < N):N
477+ push!(stup.args, Expr(:ref, :s, n))
478+ end
479+ quote
480+ $(Expr(:meta,:inline))
481+ @inbounds $stup
482+ end
483+ end
484+ =#
485+
486+ """
487+ offsets(A) -> Tuple
279488
280489Returns offsets of indices with respect to 0. If values are known at compile time,
281490it should return them as `Static` numbers.
294503 strd = stride (parent (x), One ())
295504 (strd, strd)
296505end
297-
506+
298507@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
299508 if C ≤ 0 || C > N
300509 return Expr (:block , Expr (:meta ,:inline ), :s )
@@ -325,15 +534,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325534 quote
326535 $ (Expr (:meta ,:inline ))
327536 @inbounds $ stup
328- end
537+ end
329538 end
330539end
331540
332- @inline function offsets (x, i)
333- inds = indices (x, i)
334- start = known_first (inds)
335- isnothing (start) ? first (inds) : StaticInt (start)
336- end
541+ @inline offsets (x, i) = static_first (indices (x, i))
337542# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338543# Explicit tuple needed for inference.
339544@generated function offsets (A:: AbstractArray{<:Any,N} ) where {N}
0 commit comments