@@ -28,60 +28,25 @@ argdims(s::ArrayStyle, arg) = argdims(s, typeof(arg))
2828argdims (:: ArrayStyle , :: Type{T} ) where {T} = static (0 )
2929argdims (:: ArrayStyle , :: Type{T} ) where {T<: Colon } = static (1 )
3030argdims (:: ArrayStyle , :: Type{T} ) where {T<: AbstractArray } = static (ndims (T))
31- argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: CartesianIndex {N} } = static (N)
32- argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractArray{CartesianIndex {N}} } = static (N)
31+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractCartesianIndex {N} } = static (N)
32+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractArray{<:AbstractCartesianIndex {N}} } = static (N)
3333argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractArray{<:Any,N} } = static (N)
3434argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: LogicalIndex{<:Any,<:AbstractArray{Bool,N}} } = static (N)
3535_argdims (s:: ArrayStyle , :: Type{I} , i:: StaticInt ) where {I} = argdims (s, _get_tuple (I, i))
3636function argdims (s:: ArrayStyle , :: Type{T} ) where {N,T<: Tuple{Vararg{Any,N}} }
3737 return eachop (_argdims, nstatic (Val (N)), s, T)
3838end
3939
40- is_element_index (i) = is_element_index (typeof (i))
41- is_element_index (:: Type{T} ) where {T} = static (false )
42- is_element_index (:: Type{T} ) where {T<: AbstractCartesianIndex } = static (true )
43- is_element_index (:: Type{T} ) where {T<: Integer } = static (true )
44- _is_element_index (:: Type{T} , i:: StaticInt ) where {T} = is_element_index (_get_tuple (T, i))
45- function is_element_index (:: Type{T} ) where {N,T<: Tuple{Vararg{Any,N}} }
46- return static (all (eachop (_is_element_index, nstatic (Val (N)), T)))
47- end
48-
49- """
50- UnsafeIndex(::ArrayStyle, ::Type{I})
51-
52- `UnsafeIndex` controls how indices that have been bounds checked and converted to
53- native axes' indices are used to return the stored values of an array. For example,
54- if the indices at each dimension are single integers then `UnsafeIndex(array, inds)` returns
55- `UnsafeGetElement()`. Conversely, if any of the indices are vectors then `UnsafeGetCollection()`
56- is returned, indicating that a new array needs to be reconstructed. This method permits
57- customizing the terminal behavior of the indexing pipeline based on arguments passed
58- to `ArrayInterface.getindex`. New subtypes of `UnsafeIndex` should define `promote_rule`.
59- """
60- abstract type UnsafeIndex end
61-
62- struct UnsafeGetElement <: UnsafeIndex end
63-
64- struct UnsafeGetCollection <: UnsafeIndex end
65-
66- UnsafeIndex (x, i) = UnsafeIndex (x, typeof (i))
67- UnsafeIndex (x, :: Type{I} ) where {I} = UnsafeIndex (ArrayStyle (x), I)
68- UnsafeIndex (s:: ArrayStyle , i) = UnsafeIndex (s, typeof (i))
69- UnsafeIndex (:: ArrayStyle , :: Type{I} ) where {I} = UnsafeGetElement ()
70- UnsafeIndex (:: ArrayStyle , :: Type{I} ) where {I<: AbstractArray } = UnsafeGetCollection ()
71-
72- Base. promote_rule (:: Type{X} , :: Type{Y} ) where {X<: UnsafeIndex ,Y<: UnsafeGetElement } = X
73-
74- @generated function UnsafeIndex (s:: ArrayStyle , :: Type{T} ) where {N,T<: Tuple{Vararg{Any,N}} }
75- if N === 0
76- return UnsafeGetElement ()
77- else
78- e = Expr (:call , promote_type)
79- for p in T. parameters
80- push! (e. args, :(typeof (ArrayInterface. UnsafeIndex (s, $ p))))
81- end
82- return Expr (:block , Expr (:meta , :inline ), Expr (:call , e))
83- end
40+ _is_element_index (i) = _is_element_index (typeof (i))
41+ _is_element_index (:: Type{T} ) where {T} = static (false )
42+ _is_element_index (:: Type{T} ) where {T<: AbstractCartesianIndex } = static (true )
43+ _is_element_index (:: Type{T} ) where {T<: Integer } = static (true )
44+ __is_element_index (:: Type{T} , i:: StaticInt ) where {T} = _is_element_index (_get_tuple (T, i))
45+ function _is_element_index (:: Type{T} ) where {N,T<: Tuple{Vararg{Any,N}} }
46+ return static (all (eachop (__is_element_index, nstatic (Val (N)), T)))
8447end
48+ # empty tuples refer to the single element of 0-dimensional arrays
49+ _is_element_index (:: Type{Tuple{}} ) = static (true )
8550
8651# are the indexing arguments provided a linear collection into a multidim collection
8752is_linear_indexing (A, args:: Tuple{Arg} ) where {Arg} = argdims (A, Arg) < 2
@@ -181,6 +146,22 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
181146@propagate_inbounds function to_index (:: IndexLinear , axis, arg:: AbstractCartesianIndex{1} )
182147 return to_index (axis, first (Tuple (arg)))
183148end
149+ function to_index (:: IndexLinear , x, arg:: AbstractCartesianIndex{N} ) where {N}
150+ inds = Tuple (arg)
151+ o = offsets (x)
152+ s = size (x)
153+ return first (inds) + (offset1 (x) - first (o)) + _subs2int (first (s), tail (s), tail (o), tail (inds))
154+ end
155+ @inline function _subs2int (stride, s:: Tuple{Any,Vararg} , o:: Tuple{Any,Vararg} , inds:: Tuple{Any,Vararg} )
156+ i = ((first (inds) - first (o)) * stride)
157+ return i + _subs2int (stride * first (s), tail (s), tail (o), tail (inds))
158+ end
159+ function _subs2int (stride, s:: Tuple{Any} , o:: Tuple{Any} , inds:: Tuple{Any} )
160+ return (first (inds) - first (o)) * stride
161+ end
162+ # trailing inbounds can only be 1 or 1:1
163+ _subs2int (stride, :: Tuple{} , :: Tuple{} , :: Tuple{Any} ) = static (0 )
164+
184165@propagate_inbounds function to_index (:: IndexLinear , x, arg:: Union{Array{Bool}, BitArray} )
185166 @boundscheck checkbounds (x, arg)
186167 return LogicalIndex {Int} (arg)
194175 return arg
195176end
196177@propagate_inbounds function to_index (:: IndexLinear , x, arg:: Integer )
197- @boundscheck checkindex (Bool, x , arg) || throw (BoundsError (x, arg))
178+ @boundscheck checkindex (Bool, indices (x) , arg) || throw (BoundsError (x, arg))
198179 return _int (arg)
199180end
200181@propagate_inbounds function to_index (:: IndexLinear , axis, arg:: AbstractArray{Bool} )
@@ -209,25 +190,11 @@ end
209190 @boundscheck checkindex (Bool, indices (axis), arg) || throw (BoundsError (axis, arg))
210191 return static_first (arg): static_step (arg): static_last (arg)
211192end
212- to_index (:: IndexLinear , x, inds:: Tuple{Any} ) = first (inds)
213- function to_index (:: IndexLinear , x, inds:: Tuple{Any,Vararg{Any}} )
214- o = offsets (x)
215- s = size (x)
216- return first (inds) + (offset1 (x) - first (o)) + _subs2int (first (s), tail (s), tail (o), tail (inds))
217- end
218- @inline function _subs2int (stride, s:: Tuple{Any,Vararg} , o:: Tuple{Any,Vararg} , inds:: Tuple{Any,Vararg} )
219- i = ((first (inds) - first (o)) * stride)
220- return i + _subs2int (stride * first (s), tail (s), tail (o), tail (inds))
221- end
222- function _subs2int (stride, s:: Tuple{Any} , o:: Tuple{Any} , inds:: Tuple{Any} )
223- return (first (inds) - first (o)) * stride
224- end
225- # trailing inbounds can only be 1 or 1:1
226- _subs2int (stride, :: Tuple{} , :: Tuple{} , :: Tuple{Any} ) = static (0 )
227193
228194# # IndexCartesian ##
229195to_index (:: IndexCartesian , x, arg:: Colon ) = CartesianIndices (x)
230196to_index (:: IndexCartesian , x, arg:: CartesianIndices{0} ) = arg
197+ to_index (:: IndexCartesian , x, arg:: AbstractCartesianIndex ) = arg
231198function to_index (:: IndexCartesian , x, arg)
232199 @boundscheck _multi_check_index (axes (x), arg) || throw (BoundsError (x, arg))
233200 return arg
@@ -253,15 +220,13 @@ end
253220 @boundscheck checkbounds (x, arg)
254221 return LogicalIndex {Int} (arg)
255222end
256- to_index (:: IndexCartesian , x, i:: Integer ) = _int2subs (axes (x), i - offset1 (x))
257- @inline function _int2subs (axs:: Tuple{Any,Vararg{Any}} , i)
258- axis = first (axs)
259- len = static_length (axis)
223+ to_index (:: IndexCartesian , x, i:: Integer ) = NDIndex (_int2subs (offsets (x), size (x), i - offset1 (x)))
224+ @inline function _int2subs (o:: Tuple{Any,Vararg{Any}} , s:: Tuple{Any,Vararg{Any}} , i)
225+ len = first (s)
260226 inext = div (i, len)
261- return (_int (i - len * inext + static_first (axis )), _int2subs (tail (axs ), inext)... )
227+ return (_int (i - len * inext + first (o )), _int2subs (tail (o), tail (s ), inext)... )
262228end
263- _int2subs (axs:: Tuple{Any} , i) = _int (i + static_first (first (axs)))
264-
229+ _int2subs (o:: Tuple{Any} , s:: Tuple{Any} , i) = _int (i + first (o))
265230
266231"""
267232 unsafe_reconstruct(A, data; kwargs...)
353318end
354319to_axis (S:: IndexLinear , axis, inds) = StaticInt (1 ): static_length (inds)
355320
321+ # ###############
322+ # ## getindex ###
323+ # ###############
356324"""
357325 ArrayInterface.getindex(A, args...)
358326
@@ -362,14 +330,19 @@ Changing indexing based on a given argument from `args` should be done through,
362330[`to_index`](@ref), or [`to_axis`](@ref).
363331"""
364332@propagate_inbounds getindex (A, args... ) = unsafe_get_index (A, to_indices (A, args))
365- @propagate_inbounds getindex (A; kwargs... ) = A[order_named_inds (dimnames (A), kwargs. data)... ]
333+ @propagate_inbounds function getindex (A; kwargs... )
334+ return unsafe_get_index (A, to_indices (A, order_named_inds (dimnames (A), kwargs. data)))
335+ end
366336@propagate_inbounds getindex (x:: Tuple , i:: Int ) = getfield (x, i)
367337@propagate_inbounds getindex (x:: Tuple , :: StaticInt{i} ) where {i} = getfield (x, i)
368338
369339# # unsafe_get_index ##
370- unsafe_get_index (A, inds:: Tuple ) = _unsafe_get_index (is_element_index (inds), A, inds)
371- _unsafe_get_index (:: True , A, inds:: Tuple ) = unsafe_get_element (A, inds)
340+ unsafe_get_index (A, inds:: Tuple ) = _unsafe_get_index (_is_element_index (inds), A, inds)
372341_unsafe_get_index (:: False , A, inds:: Tuple ) = unsafe_get_collection (A, inds)
342+ _unsafe_get_index (:: True , A, inds:: Tuple ) = __unsafe_get_index (A, inds)
343+ __unsafe_get_index (A, inds:: Tuple{} ) = unsafe_get_element (A, ())
344+ __unsafe_get_index (A, inds:: Tuple{Any} ) = unsafe_get_element (A, first (inds))
345+ __unsafe_get_index (A, inds:: Tuple{Any,Vararg{Any}} ) = unsafe_get_element (A, NDIndex (inds))
373346
374347"""
375348 unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
@@ -380,22 +353,30 @@ must define `unsafe_get_element(::NewArrayType, inds)`.
380353"""
381354unsafe_get_element (a:: A , inds) where {A} = _unsafe_get_element (has_parent (A), a, inds)
382355_unsafe_get_element (:: True , a, inds) = unsafe_get_element (parent (a), inds)
383- _unsafe_get_element (:: False , a, inds) = @inbounds (parent (a)[inds... ])
384- _unsafe_get_element (:: False , a:: AbstractArray2 , inds) = unsafe_get_element_error (a, inds)
356+ _unsafe_get_element (:: False , a, inds) = @inbounds (parent (a)[inds])
357+ _unsafe_get_element (:: False , a:: AbstractArray2 , i) = unsafe_get_element_error (a, i)
358+
359+ # # Array ##
385360unsafe_get_element (A:: Array , :: Tuple{} ) = Base. arrayref (false , A, 1 )
386- unsafe_get_element (A:: Array , inds) = Base. arrayref (false , A, Int (to_index (A, inds)))
387- unsafe_get_element (A:: LinearIndices , inds) = Int (to_index (A, inds))
388- @inline function unsafe_get_element (A:: CartesianIndices , inds)
389- if length (inds) === 1
390- return CartesianIndex (to_index (A, first (inds)))
391- else
392- return CartesianIndex (Base. _to_subscript_indices (A, inds... ))
393- end
361+ unsafe_get_element (A:: Array , i:: Integer ) = Base. arrayref (false , A, Int (i))
362+ unsafe_get_element (A:: Array , i:: NDIndex ) = unsafe_get_element (A, to_index (A, i))
363+
364+ # # LinearIndices ##
365+ unsafe_get_element (A:: LinearIndices , i:: Integer ) = Int (i)
366+ unsafe_get_element (A:: LinearIndices , i:: NDIndex ) = unsafe_get_element (A, to_index (A, i))
367+
368+ unsafe_get_element (A:: CartesianIndices , i:: NDIndex ) = CartesianIndex (i)
369+ unsafe_get_element (A:: CartesianIndices , i:: Integer ) = unsafe_get_element (A, to_index (A, i))
370+
371+ unsafe_get_element (A:: ReshapedArray , i:: Integer ) = unsafe_get_element (parent (A), i)
372+ function unsafe_get_element (A:: ReshapedArray , i:: NDIndex )
373+ return unsafe_get_element (parent (A), to_index (IndexLinear (), A, i))
394374end
395- unsafe_get_element (A:: ReshapedArray , inds) = @inbounds (A[inds... ])
396- unsafe_get_element (A:: SubArray , inds) = @inbounds (A[inds... ])
397375
398- unsafe_get_element_error (A, inds) = throw (MethodError (unsafe_get_element, (A, inds)))
376+ unsafe_get_element (A:: SubArray , i) = @inbounds (A[i])
377+ function unsafe_get_element_error (@nospecialize (A), @nospecialize (i))
378+ throw (MethodError (unsafe_get_element, (A, i)))
379+ end
399380
400381# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
401382"""
@@ -424,7 +405,7 @@ function _generate_unsafe_get_index!_body(N::Int)
424405 # the optimizer is not clever enough to split the union without it
425406 Dy === nothing && return dest
426407 (idx, state) = Dy
427- dest[idx] = unsafe_get_element (src, Base. Cartesian. @ntuple ($ N, j))
408+ dest[idx] = unsafe_get_element (src, NDIndex ( Base. Cartesian. @ntuple ($ N, j) ))
428409 Dy = iterate (D, state)
429410 end
430411 return dest
@@ -453,37 +434,36 @@ end
453434 end
454435end
455436
437+ # ################
438+ # ## setindex! ###
439+ # ################
456440"""
457441 ArrayInterface.setindex!(A, args...)
458442
459443Store the given values at the given key or index within a collection.
460444"""
461445@propagate_inbounds function setindex! (A, val, args... )
462446 if can_setindex (A)
463- return unsafe_setindex ! (A, val, to_indices (A, args))
447+ return unsafe_set_index ! (A, val, to_indices (A, args))
464448 else
465449 error (" Instance of type $(typeof (A)) are not mutable and cannot change elements after construction." )
466450 end
467451end
468452@propagate_inbounds function setindex! (A, val; kwargs... )
469- if has_dimnames (A)
470- return setindex! (A, val, order_named_inds (dimnames (A), kwargs. data)... )
471- else
472- return unsafe_setindex! (A, val, to_indices (A, ()))
473- end
453+ return unsafe_set_index! (A, val, to_indices (A, order_named_inds (dimnames (A), kwargs. data)))
474454end
475455
476- """
477- unsafe_setindex!(A, val, inds::Tuple)
478-
479- Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
480- bounds-checked. This step of the processing pipeline can be customized by:
481- """
482- unsafe_setindex! (A, val, i:: Tuple ) = unsafe_setindex! (UnsafeIndex (A, i), A, val, i)
483- unsafe_setindex! (:: UnsafeGetElement , A, val, i:: Tuple ) = unsafe_set_element! (A, val, i)
484- unsafe_setindex! (:: UnsafeGetCollection , A, v, i:: Tuple ) = unsafe_set_collection! (A, v, i)
456+ unsafe_set_index! (A, v, inds:: Tuple ) = _unsafe_set_index! (_is_element_index (inds), A, v, inds)
457+ _unsafe_set_index! (:: False , A, v, inds:: Tuple ) = unsafe_set_collection! (A, v, inds)
458+ _unsafe_set_index! (:: True , A, v, inds:: Tuple ) = __unsafe_set_index! (A, v, inds)
459+ __unsafe_set_index! (A, v, inds:: Tuple{} ) = unsafe_set_element! (A, v, ())
460+ function __unsafe_set_index! (A, v, inds:: Tuple{Any} )
461+ return unsafe_set_element! (A, v, to_index (A, first (inds)))
462+ end
463+ function __unsafe_set_index! (A, v, inds:: Tuple{Any,Vararg{Any}} )
464+ return unsafe_set_element! (A, v, to_index (A, NDIndex (inds)))
465+ end
485466
486- unsafe_set_element_error (A, v, i) = throw (MethodError (unsafe_set_element!, (A, v, i)))
487467
488468"""
489469 unsafe_set_element!(A, val, inds::Tuple)
@@ -494,19 +474,18 @@ must define `unsafe_set_element!(::NewArrayType, val, inds)`.
494474"""
495475unsafe_set_element! (a, val, inds) = _unsafe_set_element! (has_parent (a), a, val, inds)
496476_unsafe_set_element! (:: True , a, val, inds) = unsafe_set_element! (parent (a), val, inds)
497- _unsafe_set_element! (:: False , a, val,inds) = @inbounds (parent (a)[inds... ] = val)
477+ _unsafe_set_element! (:: False , a, val, inds) = @inbounds (parent (a)[inds] = val)
478+
498479function _unsafe_set_element! (:: False , a:: AbstractArray2 , val, inds)
499480 unsafe_set_element_error (a, val, inds)
500481end
482+ unsafe_set_element_error (A, v, i) = throw (MethodError (unsafe_set_element!, (A, v, i)))
501483
502- function unsafe_set_element! (A:: Array{T} , val, inds:: Tuple ) where {T}
503- if length (inds) === 0
504- return Base. arrayset (false , A, convert (T, val):: T , 1 )
505- elseif inds isa Tuple{Vararg{Int}}
506- return Base. arrayset (false , A, convert (T, val):: T , inds... )
507- else
508- throw (MethodError (unsafe_set_element!, (A, inds)))
509- end
484+ function unsafe_set_element! (A:: Array{T} , val, :: Tuple{} ) where {T}
485+ Base. arrayset (false , A, convert (T, val):: T , 1 )
486+ end
487+ function unsafe_set_element! (A:: Array{T} , val, i:: Integer ) where {T}
488+ return Base. arrayset (false , A, convert (T, val):: T , Int (i))
510489end
511490
512491# This is based on Base._unsafe_setindex!.
@@ -529,7 +508,7 @@ function _generate_unsafe_setindex!_body(N::Int)
529508 # the optimizer that it does not need to emit error paths
530509 Xy === nothing && break
531510 (val, state) = Xy
532- unsafe_set_element! (A, val, Base. Cartesian. @ntuple ($ N, i))
511+ unsafe_set_element! (A, val, NDIndex ( Base. Cartesian. @ntuple ($ N, i) ))
533512 Xy = iterate (x′, state)
534513 end
535514 A
0 commit comments