diff --git a/Project.toml b/Project.toml index 4ecad6f..087e67c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" -version = "0.3.3" authors = ["ITensor developers and contributors"] +version = "0.3.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -18,12 +18,10 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] KroneckerArraysTensorAlgebraExt = "TensorAlgebra" -KroneckerArraysTensorProductsExt = "TensorProducts" [compat] Adapt = "4.3" @@ -36,7 +34,6 @@ GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.6" -TensorAlgebra = "0.5" -TensorProducts = "0.1.7" +TensorAlgebra = "0.6.2" TypeParameterAccessors = "0.4.2" julia = "1.10" diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index 28f1618..921ab96 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -1,8 +1,9 @@ module KroneckerArraysTensorAlgebraExt -using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors +using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange, + ⊗, cartesianrange, kroneckerfactors, kroneckerfactortypes using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, - FusionStyle, matricize, unmatricize + FusionStyle, matricize, tensor_product_axis, unmatricize struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle a::A @@ -11,33 +12,49 @@ end KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b) KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B) -function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) - return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...) +function TensorAlgebra.FusionStyle(A::Type{<:AbstractKroneckerArray}) + return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) end +function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange}) + return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) +end + +function TensorAlgebra.tensor_product_axis( + style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange + ) + style_a, style_b = kroneckerfactors(style) + r1a, r1b = kroneckerfactors(r1) + r2a, r2b = kroneckerfactors(r2) + ra = tensor_product_axis(style_a, r1a, r2a) + rb = tensor_product_axis(style_b, r1b, r2b) + return cartesianrange(ra, rb) +end + function matricize_kronecker( - style::FusionStyle, a::AbstractArray, length1::Val, length2::Val + style::FusionStyle, a::AbstractArray, length_codomain::Val ) - m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2) - m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2) + m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length_codomain) + m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length_codomain) return m1 ⊗ m2 end function TensorAlgebra.matricize( - style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val + style::KroneckerFusion, a::AbstractArray, length_codomain::Val ) - return matricize_kronecker(style, a, length1, length2) + return matricize_kronecker(style, a, length_codomain) end + function unmatricize_kronecker( style::FusionStyle, m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) style1, style2 = kroneckerfactors(style) m1, m2 = kroneckerfactors(m) - codomain1 = kroneckerfactors.(codomain_axes, 1) - codomain2 = kroneckerfactors.(codomain_axes, 2) - domain1 = kroneckerfactors.(domain_axes, 1) - domain2 = kroneckerfactors.(domain_axes, 2) + codomain1 = kroneckerfactors.(axes_codomain, 1) + codomain2 = kroneckerfactors.(axes_codomain, 2) + domain1 = kroneckerfactors.(axes_domain, 1) + domain2 = kroneckerfactors.(axes_domain, 2) a1 = unmatricize(style1, m1, codomain1, domain1) a2 = unmatricize(style2, m2, codomain2, domain2) return a1 ⊗ a2 diff --git a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl deleted file mode 100644 index f227960..0000000 --- a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -module KroneckerArraysTensorProductsExt - -using TensorProducts: TensorProducts, tensor_product -using KroneckerArrays: CartesianProductOneTo, kroneckerfactors, cartesianrange, unproduct - -function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo) - return cartesianrange( - tensor_product(kroneckerfactors(a1, 1), kroneckerfactors(a2, 1)), - tensor_product(kroneckerfactors(a1, 2), kroneckerfactors(a2, 2)), - tensor_product(unproduct(a1), unproduct(a2)) - ) -end - -end diff --git a/test/Project.toml b/test/Project.toml index 32cd42a..6717a93 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,7 +15,6 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -38,7 +37,6 @@ MatrixAlgebraKit = "0.6" SafeTestsets = "0.1" StableRNGs = "1.0" Suppressor = "0.2" -TensorAlgebra = "0.5" -TensorProducts = "0.1.7" +TensorAlgebra = "0.6.2" Test = "1.10" TestExtras = "0.3" diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index 79763bf..ecbcd8e 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -1,10 +1,21 @@ -using TensorAlgebra: matricize, unmatricize -using KroneckerArrays: ⊗, kroneckerfactors +using TensorAlgebra: matricize, tensor_product_axis, unmatricize +using KroneckerArrays: ⊗, cartesianrange, kroneckerfactors, unproduct using Test: @test, @testset @testset "TensorAlgebraExt" begin - a = randn(2, 2, 2) ⊗ randn(3, 3, 3) - m = matricize(a, (1, 2), (3,)) - @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ matricize(kroneckerfactors(a, 2), (1, 2), (3,)) - @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a + @testset "tensor_product_axis" begin + r1 = cartesianrange(2, 3) + r2 = cartesianrange(4, 5) + r = tensor_product_axis(r1, r2) + @test r ≡ cartesianrange(8, 15) + @test kroneckerfactors(r, 1) ≡ Base.OneTo(8) + @test kroneckerfactors(r, 2) ≡ Base.OneTo(15) + @test unproduct(r) ≡ Base.OneTo(120) + end + @testset "matricize/unmatricize" begin + a = randn(2, 2, 2) ⊗ randn(3, 3, 3) + m = matricize(a, (1, 2), (3,)) + @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ matricize(kroneckerfactors(a, 2), (1, 2), (3,)) + @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a + end end diff --git a/test/test_tensorproducts.jl b/test/test_tensorproducts.jl deleted file mode 100644 index d8ed112..0000000 --- a/test/test_tensorproducts.jl +++ /dev/null @@ -1,13 +0,0 @@ -using KroneckerArrays: ×, kroneckerfactors, cartesianrange, unproduct -using TensorProducts: tensor_product -using Test: @test, @testset - -@testset "KroneckerArraysTensorProductsExt" begin - r1 = cartesianrange(2, 3) - r2 = cartesianrange(4, 5) - r = tensor_product(r1, r2) - @test r ≡ cartesianrange(8, 15) - @test kroneckerfactors(r, 1) ≡ Base.OneTo(8) - @test kroneckerfactors(r, 2) ≡ Base.OneTo(15) - @test unproduct(r) ≡ Base.OneTo(120) -end