Skip to content

Commit dbd3b42

Browse files
committed
Safe co-iteration across an axis for 1+ arrays
1 parent 4a959b1 commit dbd3b42

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ Returns the parent array that `x` wraps.
2020
Returns `true` if the size of `T` can change, in which case operations
2121
such as `pop!` and `popfirst!` are available for collections of type `T`.
2222

23+
## indices(x[, d])
24+
25+
Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
26+
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
27+
returned. If any indices are not equal along dimension `d` an error is thrown. A
28+
tuple may be used to specify a different dimension for each array. If `d` is not
29+
specified then indices for visiting each index of `x` is returned.
30+
31+
2332
## ismutable(x)
2433

2534
A trait function for whether `x` is a mutable or immutable array. Used for

src/ArrayInterface.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,46 @@ known_step(x) = known_step(typeof(x))
542542
known_step(::Type{T}) where {T} = nothing
543543
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
544544

545+
546+
"""
547+
indices(x[, d]) -> AbstractRange
548+
549+
Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
550+
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
551+
returned. If any indices are not equal along dimension `d` an error is thrown. A
552+
tuple may be used to specify a different dimension for each array. If `d` is not
553+
specified then indices for visiting each index of `x` is returned.
554+
"""
555+
@inline indices(x) = eachindex(x)
556+
557+
indices(x, d) = indices(axes(x, d))
558+
559+
@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
560+
inds = indices(first(x), dim)
561+
@assert _check_indices(inds, Base.tail(x), dim) "The indices along dimension $dim are not equal for all $x"
562+
return inds
563+
end
564+
565+
@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
566+
ind = indices(first(x), first(dim))
567+
@assert _check_indices(ind, Base.tail(x), Base.tail(dim)) "The indices along dimension $dim are not equal for all $x"
568+
return ind
569+
end
570+
571+
@inline function _check_indices(ind, x::Tuple, dim::Tuple)
572+
for (x_i, d_i) in zip(x, dim)
573+
ind == indices(x_i, d_i) || return false
574+
end
575+
return true
576+
end
577+
578+
@inline function _check_indices(ind, x::Tuple, d)
579+
for x_i in x
580+
ind == indices(x_i, d) || return false
581+
end
582+
return true
583+
end
584+
545585
function __init__()
546586

547587
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,12 @@ end
196196
@test !ArrayInterface.can_change_size(Tuple{})
197197
end
198198

199+
@testset "indices" begin
200+
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
201+
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
202+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
203+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
204+
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
205+
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
206+
end
207+

0 commit comments

Comments
 (0)