Skip to content

Commit d450944

Browse files
committed
add is_lazy_conjugate
1 parent 3eed3f6 commit d450944

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/ArrayInterface.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,51 @@ end
841841
return setindex!(A, val; kwargs...)
842842
end
843843

844+
845+
"""
846+
is_lazy_conjugate(::AbstractArray)
847+
848+
Determine if a given array will lazyily take complex conjugates, such as with `Adjoint`. This will work with
849+
nested wrappers, so long as there is no type in the chain of wrappers such that `parent_type(T) == T`
850+
851+
Examples
852+
853+
julia> a = transpose([1 + im, 1-im]')
854+
2×1 transpose(adjoint(::Vector{Complex{Int64}})) with eltype Complex{Int64}:
855+
1 - 1im
856+
1 + 1im
857+
858+
julia> ArrayInterface.is_lazy_conjugate(a)
859+
true
860+
861+
julia> b = a'
862+
1×2 adjoint(transpose(adjoint(::Vector{Complex{Int64}}))) with eltype Complex{Int64}:
863+
1+1im 1-1im
864+
865+
julia> ArrayInterface.is_lazy_conjugate(b)
866+
false
867+
868+
"""
869+
is_lazy_conjugate(::T) where {T <: AbstractArray} = _is_lazy_conjugate(T, false)
870+
871+
function _is_lazy_conjugate(::Type{T}, isconj) where {T <: AbstractArray}
872+
Tp = parent_type(T)
873+
if T !== Tp
874+
_is_lazy_conjugate(Tp, isconj)
875+
else
876+
isconj
877+
end
878+
end
879+
880+
function _is_lazy_conjugate(::Type{T}, isconj) where {T <: Adjoint}
881+
Tp = parent_type(T)
882+
if T !== Tp
883+
_is_lazy_conjugate(Tp, !isconj)
884+
else
885+
!isconj
886+
end
887+
end
888+
844889
include("ranges.jl")
845890
include("dimensions.jl")
846891
include("axes.jl")
@@ -849,6 +894,8 @@ include("indexing.jl")
849894
include("stridelayout.jl")
850895
include("broadcast.jl")
851896

897+
898+
852899
function __init__()
853900

854901
@require SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin

test/runtests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using Base: setindex
33
using IfElse
44
using ArrayInterface: StaticInt, True, False
55
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
6-
device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, static, NDIndex
6+
device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, static, NDIndex,
7+
is_lazy_conjugate
78

89

910
if VERSION v"1.6"
@@ -786,4 +787,19 @@ include("dimensions.jl")
786787
include("broadcast.jl")
787788
end
788789

790+
using Test
791+
using ArrayInterface: is_lazy_conjugate
792+
793+
@testset "lazy conj" begin
794+
a = rand(ComplexF64, 2)
795+
@test is_lazy_conjugate(a) == false
796+
b = a'
797+
@test is_lazy_conjugate(b) == true
798+
c = transpose(b)
799+
@test is_lazy_conjugate(c) == true
800+
d = c'
801+
@test is_lazy_conjugate(d) == false
802+
e = permutedims(d)
803+
@test is_lazy_conjugate(e) == false
804+
end
789805

0 commit comments

Comments
 (0)