8787
8888
8989"""
90- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<: Val}}
90+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{Val}}
9191
9292Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
9393"""
@@ -98,53 +98,6 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
9898Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} =
9999 ntuple (d -> Val {d == N} (), Val {D} ())
100100
101- """
102- If the contiguous dimension is not the dimension with `StrideRank{1}`:
103- """
104- struct ContiguousBatch{N} end
105- Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
106- _get (:: ContiguousBatch{N} ) where {N} = N
107-
108- """
109- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
110-
111- Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
112- If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
113- If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
114- If unknown, it will return `nothing`.
115- """
116- contiguous_batch_size (x) = contiguous_batch_size (typeof (x))
117- contiguous_batch_size (:: Type ) = nothing
118- contiguous_batch_size (:: Type{Array{T,N}} ) where {T,N} = ContiguousBatch {0} ()
119- contiguous_batch_size (:: Type{<:Tuple} ) = ContiguousBatch {0} ()
120- contiguous_batch_size (
121- :: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
122- ) where {T,A<: AbstractVecOrMat{T} } = contiguous_batch_size (A)
123- contiguous_batch_size (
124- :: Type{<:PermutedDimsArray{T,N,I1,I2,A}} ,
125- ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = contiguous_batch_size (A)
126- function contiguous_batch_size (
127- :: Type{S} ,
128- ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
129- _contiguous_batch_size (S, contiguous_batch_size (A), contiguous_axis (A))
130- end
131- _contiguous_batch_size (:: Any , :: Any , :: Any ) = nothing
132- @generated function _contiguous_batch_size (
133- :: Type{S} ,
134- :: ContiguousBatch{B} ,
135- :: Contiguous{C} ,
136- ) where {B,C,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
137- if I. parameters[C] <: AbstractUnitRange
138- Expr (:call , Expr (:curly , :ContiguousBatch , B))
139- else
140- Expr (:call , Expr (:curly , :ContiguousBatch , - 1 ))
141- end
142- end
143-
144- contiguous_batch_size (
145- :: Type{R} ,
146- ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } = ContiguousBatch {0} ()
147-
148101struct StrideRank{R} end
149102Base. @pure StrideRank (R:: NTuple{<:Any,Int} ) = StrideRank {R} ()
150103_get (:: StrideRank{R} ) where {R} = R
@@ -230,6 +183,67 @@ stride_rank(x, i) = stride_rank(x)[i]
230183stride_rank (:: Type{R} ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } =
231184 StrideRank {ntuple(identity, Val{N}())} ()
232185
186+ function stride_rank (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
187+
188+ _reshaped_striderank (is_column_major (P), Val {N} (), Val {M} ())
189+ end
190+ _reshaped_striderank (:: Val{true} , :: Val{N} , :: Val{0} ) where {N} = StrideRank {ntuple(identity, Val{N}())} ()
191+ _reshaped_striderank (_, __, ___) = nothing
192+
193+
194+ """
195+ If the contiguous dimension is not the dimension with `StrideRank{1}`:
196+ """
197+ struct ContiguousBatch{N} end
198+ Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
199+ _get (:: ContiguousBatch{N} ) where {N} = N
200+
201+ """
202+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
203+
204+ Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
205+ If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
206+ If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
207+ If unknown, it will return `nothing`.
208+ """
209+ contiguous_batch_size (x) = contiguous_batch_size (typeof (x))
210+ contiguous_batch_size (:: Type{T} ) where {T} = _contiguous_batch_size (contiguous_axis (T), stride_rank (T))
211+ _contiguous_batch_size (_, __) = nothing
212+ @generated function _contiguous_batch_size (:: Contiguous{D} , :: StrideRank{R} ) where {D,R}
213+ isone (R[D]) ? :(ContiguousBatch {0} ()) : :nothing
214+ end
215+
216+ contiguous_batch_size (:: Type{Array{T,N}} ) where {T,N} = ContiguousBatch {0} ()
217+ contiguous_batch_size (:: Type{<:Tuple} ) = ContiguousBatch {0} ()
218+ contiguous_batch_size (
219+ :: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
220+ ) where {T,A<: AbstractVecOrMat{T} } = contiguous_batch_size (A)
221+ contiguous_batch_size (
222+ :: Type{<:PermutedDimsArray{T,N,I1,I2,A}} ,
223+ ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = contiguous_batch_size (A)
224+ function contiguous_batch_size (
225+ :: Type{S} ,
226+ ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
227+ _contiguous_batch_size (S, contiguous_batch_size (A), contiguous_axis (A))
228+ end
229+ _contiguous_batch_size (:: Any , :: Any , :: Any ) = nothing
230+ @generated function _contiguous_batch_size (
231+ :: Type{S} ,
232+ :: ContiguousBatch{B} ,
233+ :: Contiguous{C} ,
234+ ) where {B,C,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
235+ if I. parameters[C] <: AbstractUnitRange
236+ Expr (:call , Expr (:curly , :ContiguousBatch , B))
237+ else
238+ Expr (:call , Expr (:curly , :ContiguousBatch , - 1 ))
239+ end
240+ end
241+
242+ contiguous_batch_size (
243+ :: Type{R} ,
244+ ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } = ContiguousBatch {0} ()
245+
246+
233247"""
234248 is_column_major(A) -> Val{true/false}()
235249
@@ -260,7 +274,8 @@ An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A
260274"""
261275dense_dims (x) = dense_dims (typeof (x))
262276dense_dims (:: Type ) = nothing
263- dense_dims (:: Type{Array{T,N}} ) where {T,N} = DenseDims {ntuple(_ -> true, Val{N}())} ()
277+ _all_dense (:: Val{N} ) where {N} = DenseDims {ntuple(_ -> true, Val{N}())} ()
278+ dense_dims (:: Type{Array{T,N}} ) where {T,N} = _all_dense (Val {N} ())
264279dense_dims (:: Type{<:Tuple} ) = DenseDims {(true,)} ()
265280function dense_dims (
266281 :: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
@@ -306,6 +321,15 @@ _dense_dims(::Any, ::Any) = nothing
306321 length (dense_tup. args) == N ? Expr (:call , Expr (:curly , :DenseDims , dense_tup)) : nothing
307322end
308323
324+ function dense_dims (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
325+
326+ _reshaped_dense_dims (dense_dims (P), is_column_major (P), Val {N} (), Val {M} ())
327+ end
328+ _reshaped_dense_dims (_, __, ___, ____) = nothing
329+ @generated function _reshaped_dense_dims (:: DenseDims{D} , :: Val{true} , :: Val{N} , :: Val{0} ) where {D,N}
330+ all (D) ? :(_all_dense (Val {$N} ())) : :nothing
331+ end
332+
309333permute (t:: NTuple{N} , I:: NTuple{N,Int} ) where {N} = ntuple (n -> t[I[n]], Val {N} ())
310334@generated function permute (t:: Tuple{Vararg{Any,N}} , :: Val{I} ) where {N,I}
311335 t = Expr (:tuple )
0 commit comments