From fd836c4abeabd3897ffcc91e690069c79aa81a72 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 16:36:53 -0400 Subject: [PATCH 01/28] FusionTreePair --- src/fusiontrees/fusiontrees.jl | 5 +- src/fusiontrees/manipulations.jl | 165 +++++++++++++----------------- src/planar/planaroperations.jl | 2 +- src/tensors/braidingtensor.jl | 4 +- src/tensors/diagonal.jl | 2 +- src/tensors/indexmanipulations.jl | 8 +- src/tensors/tensoroperations.jl | 2 +- src/tensors/treetransformers.jl | 10 +- test/fusiontrees.jl | 35 ++++--- 9 files changed, 106 insertions(+), 127 deletions(-) diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 3e694e523..71d67953d 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -92,6 +92,8 @@ function FusionTree(uncoupled::NTuple{N,I}, coupled::I, end FusionTree(uncoupled::Tuple{I,Vararg{I}}) where {I<:Sector} = FusionTree(uncoupled, one(I)) +const FusionTreePair{I,N₁,N₂} = Tuple{FusionTree{I,N₁},FusionTree{I,N₂}} + # Properties sectortype(::Type{<:FusionTree{I}}) where {I<:Sector} = I FusionStyle(::Type{<:FusionTree{I}}) where {I<:Sector} = FusionStyle(I) @@ -199,8 +201,7 @@ function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I,N}) where {I,N} end # TODO: is this piracy? -function Base.convert(A::Type{<:AbstractArray}, - (f₁, f₂)::Tuple{FusionTree{I},FusionTree{I}}) where {I} +function Base.convert(A::Type{<:AbstractArray}, (f₁, f₂)::FusionTreePair{I}) where {I} F₁ = convert(A, f₁) F₂ = convert(A, f₂) sz1 = size(F₁) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index f19bb06ed..e9122e592 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -244,8 +244,7 @@ end # -> A-move (foldleft, foldright) is complicated, needs to be reexpressed in standard form # flip a duality flag of a fusion tree -function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int; - inv::Bool=false) where {I<:Sector,N₁,N₂} +function flip((f₁, f₂)::FusionTreePair{I,N₁,N₂}, i::Int; inv::Bool=false) where {I,N₁,N₂} @assert 0 < i ≤ N₁ + N₂ if i ≤ N₁ a = f₁.uncoupled[i] @@ -274,19 +273,18 @@ function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int; return SingletonDict((f₁, f₂′) => factor) end end -function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, ind; - inv::Bool=false) where {I<:Sector,N₁,N₂} +function flip((f₁, f₂)::FusionTreePair{I,N₁,N₂}, ind; inv::Bool=false) where {I,N₁,N₂} f₁′, f₂′ = f₁, f₂ factor = one(sectorscalartype(I)) for i in ind - (f₁′, f₂′), s = only(flip(f₁′, f₂′, i; inv)) + (f₁′, f₂′), s = only(flip((f₁′, f₂′), i; inv)) factor *= s end return SingletonDict((f₁′, f₂′) => factor) end # change to N₁ - 1, N₂ + 1 -function bendright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I<:Sector,N₁,N₂} +function bendright((f₁, f₂)::FusionTreePair{I,N₁,N₂}) where {I,N₁,N₂} # map final splitting vertex (a, b)<-c to fusion vertex a<-(c, dual(b)) @assert N₁ > 0 c = f₁.coupled @@ -332,15 +330,14 @@ function bendright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I< end end # change to N₁ + 1, N₂ - 1 -function bendleft(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I} +function bendleft((f₁, f₂)::FusionTreePair{I}) where {I} # map final fusion vertex c<-(a, b) to splitting vertex (c, dual(b))<-a return fusiontreedict(I)((f₁′, f₂′) => conj(coeff) - for - ((f₂′, f₁′), coeff) in bendright(f₂, f₁)) + for ((f₂′, f₁′), coeff) in bendright((f₂, f₁))) end # change to N₁ - 1, N₂ + 1 -function foldright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I<:Sector,N₁,N₂} +function foldright((f₁, f₂)::FusionTreePair{I,N₁,N₂}) where {I,N₁,N₂} # map first splitting vertex (a, b)<-c to fusion vertex b<-(dual(a), c) @assert N₁ > 0 a = f₁.uncoupled[1] @@ -396,11 +393,10 @@ function foldright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I< end end # change to N₁ + 1, N₂ - 1 -function foldleft(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I} +function foldleft((f₁, f₂)::FusionTreePair{I}) where {I} # map first fusion vertex c<-(a, b) to splitting vertex (dual(a), c)<-b return fusiontreedict(I)((f₁′, f₂′) => conj(coeff) - for - ((f₂′, f₁′), coeff) in foldright(f₂, f₁)) + for ((f₂′, f₁′), coeff) in foldright((f₂, f₁))) end # COMPOSITE DUALITY MANIPULATIONS PART 1: Repartition and transpose @@ -423,11 +419,11 @@ function iscyclicpermutation(v1, v2) end # clockwise cyclic permutation while preserving (N₁, N₂): foldright & bendleft -function cycleclockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sector} +function cycleclockwise((f₁, f₂)::FusionTreePair{I}) where {I} local newtrees if length(f₁) > 0 - for ((f1a, f2a), coeffa) in foldright(f₁, f₂) - for ((f1b, f2b), coeffb) in bendleft(f1a, f2a) + for ((f1a, f2a), coeffa) in foldright((f₁, f₂)) + for ((f1b, f2b), coeffb) in bendleft((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees) newtrees[(f1b, f2b)] = get(newtrees, (f1b, f2b), zero(coeff)) + coeff @@ -437,8 +433,8 @@ function cycleclockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sect end end else - for ((f1a, f2a), coeffa) in bendleft(f₁, f₂) - for ((f1b, f2b), coeffb) in foldright(f1a, f2a) + for ((f1a, f2a), coeffa) in bendleft((f₁, f₂)) + for ((f1b, f2b), coeffb) in foldright((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees) newtrees[(f1b, f2b)] = get(newtrees, (f1b, f2b), zero(coeff)) + coeff @@ -452,11 +448,11 @@ function cycleclockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sect end # anticlockwise cyclic permutation while preserving (N₁, N₂): foldleft & bendright -function cycleanticlockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sector} +function cycleanticlockwise((f₁, f₂)::FusionTreePair{I}) where {I} local newtrees if length(f₂) > 0 - for ((f1a, f2a), coeffa) in foldleft(f₁, f₂) - for ((f1b, f2b), coeffb) in bendright(f1a, f2a) + for ((f1a, f2a), coeffa) in foldleft((f₁, f₂)) + for ((f1b, f2b), coeffb) in bendright((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees) newtrees[(f1b, f2b)] = get(newtrees, (f1b, f2b), zero(coeff)) + coeff @@ -466,8 +462,8 @@ function cycleanticlockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<: end end else - for ((f1a, f2a), coeffa) in bendright(f₁, f₂) - for ((f1b, f2b), coeffb) in foldleft(f1a, f2a) + for ((f1a, f2a), coeffa) in bendright((f₁, f₂)) + for ((f1b, f2b), coeffb) in foldleft((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees) newtrees[(f1b, f2b)] = get(newtrees, (f1b, f2b), zero(coeff)) + coeff @@ -482,8 +478,8 @@ end # repartition double fusion tree """ - repartition(f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}, N::Int) where {I, N₁, N₂} - -> <:AbstractDict{Tuple{FusionTree{I, N}, FusionTree{I, N₁+N₂-N}}, <:Number} + repartition((f₁, f₂)::FusionTreePair{I, N₁, N₂}, N::Int) where {I, N₁, N₂} + -> <:AbstractDict{<:FusionTreePair{I, N, N₁+N₂-N}}, <:Number} Input is a double fusion tree that describes the fusion of a set of incoming uncoupled sectors to a set of outgoing uncoupled sectors, represented using the individual trees of @@ -492,35 +488,32 @@ outgoing (`f₁`) and incoming sectors (`f₂`) respectively (with identical cou repartitioning the tree by bending incoming to outgoing sectors (or vice versa) in order to have `N` outgoing sectors. """ -@inline function repartition(f₁::FusionTree{I,N₁}, - f₂::FusionTree{I,N₂}, - N::Int) where {I<:Sector,N₁,N₂} +@inline function repartition((f₁, f₂)::FusionTreePair{I,N₁,N₂}, N::Int) where {I,N₁,N₂} f₁.coupled == f₂.coupled || throw(SectorMismatch()) @assert 0 <= N <= N₁ + N₂ - return _recursive_repartition(f₁, f₂, Val(N)) + return _recursive_repartition((f₁, f₂), Val(N)) end -function _recursive_repartition(f₁::FusionTree{I,N₁}, - f₂::FusionTree{I,N₂}, - ::Val{N}) where {I<:Sector,N₁,N₂,N} +function _recursive_repartition((f₁, f₂)::FusionTreePair{I,N₁,N₂}, + ::Val{N}) where {I,N₁,N₂,N} # recursive definition is only way to get correct number of loops for # GenericFusion, but is too complex for type inference to handle, so we # precompute the parameters of the return type F₁ = fusiontreetype(I, N) F₂ = fusiontreetype(I, N₁ + N₂ - N) + FF = Tuple{F₁,F₂} T = sectorscalartype(I) coeff = one(T) if N == N₁ return fusiontreedict(I){Tuple{F₁,F₂},T}((f₁, f₂) => coeff) else local newtrees::fusiontreedict(I){Tuple{F₁,F₂},T} - for ((f₁′, f₂′), coeff1) in (N < N₁ ? bendright(f₁, f₂) : bendleft(f₁, f₂)) - for ((f₁′′, f₂′′), coeff2) in _recursive_repartition(f₁′, f₂′, Val(N)) + for ((f₁′, f₂′), coeff1) in (N < N₁ ? bendright((f₁, f₂)) : bendleft((f₁, f₂))) + for ((f₁′′, f₂′′), coeff2) in _recursive_repartition((f₁′, f₂′), Val(N)) if (@isdefined newtrees) push!(newtrees, (f₁′′, f₂′′) => coeff1 * coeff2) else - newtrees = fusiontreedict(I){Tuple{F₁,F₂},T}((f₁′′, f₂′′) => coeff1 * - coeff2) + newtrees = fusiontreedict(I){FF,T}((f₁′′, f₂′′) => coeff1 * coeff2) end end end @@ -529,9 +522,8 @@ function _recursive_repartition(f₁::FusionTree{I,N₁}, end """ - transpose(f₁::FusionTree{I}, f₂::FusionTree{I}, - p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂} - -> <:AbstractDict{Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}}, <:Number} + transpose((f₁, f₂)::FusionTreePair{I}, p::(Index2Tuple{N₁, N₂}) where {I, N₁, N₂} + -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a double fusion tree that describes the fusion of a set of incoming uncoupled sectors to a set of outgoing uncoupled sectors, represented using the individual trees of @@ -540,17 +532,15 @@ outgoing (`t1`) and incoming sectors (`t2`) respectively (with identical coupled repartitioning and permuting the tree such that sectors `p1` become outgoing and sectors `p2` become incoming. """ -function Base.transpose(f₁::FusionTree{I}, f₂::FusionTree{I}, - p1::IndexTuple{N₁}, p2::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} +function Base.transpose((f₁, f₂)::FusionTreePair{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} N = N₁ + N₂ @assert length(f₁) + length(f₂) == N - p = linearizepermutation(p1, p2, length(f₁), length(f₂)) - @assert iscyclicpermutation(p) - return fstranspose((f₁, f₂, p1, p2)) + p′ = linearizepermutation(p..., length(f₁), length(f₂)) + @assert iscyclicpermutation(p′) + return fstranspose(((f₁, f₂), p)) end -const FSTransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I}, - IndexTuple{N₁},IndexTuple{N₂}} +const FSTransposeKey{I,N₁,N₂} = Tuple{<:FusionTreePair{I},Index2Tuple{N₁,N₂}} function _fsdicttype(I, N₁, N₂) F₁ = fusiontreetype(I, N₁) @@ -560,13 +550,11 @@ function _fsdicttype(I, N₁, N₂) end @cached function fstranspose(key::FSTransposeKey{I,N₁,N₂})::_fsdicttype(I, N₁, - N₂) where {I<:Sector, - N₁, - N₂} - f₁, f₂, p1, p2 = key + N₂) where {I,N₁,N₂} + (f₁, f₂), (p1, p2) = key N = N₁ + N₂ p = linearizepermutation(p1, p2, length(f₁), length(f₂)) - newtrees = repartition(f₁, f₂, N₁) + newtrees = repartition((f₁, f₂), N₁) length(p) == 0 && return newtrees i1 = findfirst(==(1), p) @assert i1 !== nothing @@ -575,7 +563,7 @@ end while 1 < i1 <= Nhalf local newtrees′ for ((f1a, f2a), coeffa) in newtrees - for ((f1b, f2b), coeffb) in cycleanticlockwise(f1a, f2a) + for ((f1b, f2b), coeffb) in cycleanticlockwise((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees′) newtrees′[(f1b, f2b)] = get(newtrees′, (f1b, f2b), zero(coeff)) + coeff @@ -590,7 +578,7 @@ end while Nhalf < i1 local newtrees′ for ((f1a, f2a), coeffa) in newtrees - for ((f1b, f2b), coeffb) in cycleclockwise(f1a, f2a) + for ((f1b, f2b), coeffb) in cycleclockwise((f1a, f2a)) coeff = coeffa * coeffb if (@isdefined newtrees′) newtrees′[(f1b, f2b)] = get(newtrees′, (f1b, f2b), zero(coeff)) + coeff @@ -605,7 +593,7 @@ end return newtrees end -function CacheStyle(::typeof(fstranspose), k::FSTransposeKey{I}) where {I<:Sector} +function CacheStyle(::typeof(fstranspose), k::FSTransposeKey{I}) where {I} if FusionStyle(I) isa UniqueFusion return NoCache() else @@ -618,13 +606,12 @@ end # -> composite manipulations that depend on the duality (rigidity) and pivotal structure # -> planar manipulations that do not require braiding, everything is in Fsymbol (A/Bsymbol) -function planar_trace(f₁::FusionTree{I}, f₂::FusionTree{I}, - p1::IndexTuple{N₁}, p2::IndexTuple{N₂}, - q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {I<:Sector,N₁,N₂,N₃} +function planar_trace((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}, + (q1, q2)::Index2Tuple{N₃,N₃}) where {I,N₁,N₂,N₃} N = N₁ + N₂ + 2N₃ @assert length(f₁) + length(f₂) == N if N₃ == 0 - return transpose(f₁, f₂, p1, p2) + return transpose((f₁, f₂), (p1, p2)) end linearindex = (ntuple(identity, Val(length(f₁)))..., @@ -641,9 +628,9 @@ function planar_trace(f₁::FusionTree{I}, f₂::FusionTree{I}, F₁ = fusiontreetype(I, N₁) F₂ = fusiontreetype(I, N₂) newtrees = FusionTreeDict{Tuple{F₁,F₂},T}() - for ((f₁′, f₂′), coeff′) in repartition(f₁, f₂, N) - for (f₁′′, coeff′′) in planar_trace(f₁′, q1′, q2′) - for (f12′′′, coeff′′′) in transpose(f₁′′, f₂′, p1′, p2′) + for ((f₁′, f₂′), coeff′) in repartition((f₁, f₂), N) + for (f₁′′, coeff′′) in planar_trace(f₁′, (q1′, q2′)) + for (f12′′′, coeff′′′) in transpose((f₁′′, f₂′), (p1′, p2′)) coeff = coeff′ * coeff′′ * coeff′′′ if !iszero(coeff) newtrees[f12′′′] = get(newtrees, f12′′′, zero(coeff)) + coeff @@ -655,15 +642,14 @@ function planar_trace(f₁::FusionTree{I}, f₂::FusionTree{I}, end """ - planar_trace(f::FusionTree{I,N}, q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {I<:Sector,N,N₃} + planar_trace(f::FusionTree{I,N}, (q1, q2)::Index2Tuple{N₃,N₃}) where {I,N,N₃} -> <:AbstractDict{FusionTree{I,N-2*N₃}, <:Number} Perform a planar trace of the uncoupled indices of the fusion tree `f` at `q1` with those at `q2`, where `q1[i]` is connected to `q2[i]` for all `i`. The result is returned as a dictionary of output trees and corresponding coefficients. """ -function planar_trace(f::FusionTree{I,N}, - q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {I<:Sector,N,N₃} +function planar_trace(f::FusionTree{I,N}, (q1, q2)::Index2Tuple{N₃,N₃}) where {I,N,N₃} T = sectorscalartype(I) F = fusiontreetype(I, N - 2 * N₃) newtrees = FusionTreeDict{F,T}() @@ -697,7 +683,7 @@ function planar_trace(f::FusionTree{I,N}, map(l -> (l - (l > i) - (l > j)), TupleTools.deleteat(q2, k)) end for (f′, coeff′) in elementary_trace(f, i) - for (f′′, coeff′′) in planar_trace(f′, q1′, q2′) + for (f′′, coeff′′) in planar_trace(f′, (q1′, q2′)) coeff = coeff′ * coeff′′ if !iszero(coeff) newtrees[f′′] = get(newtrees, f′′, zero(coeff)) + coeff @@ -709,13 +695,13 @@ end # trace two neighbouring indices of a single fusion tree """ - elementary_trace(f::FusionTree{I,N}, i) where {I<:Sector,N} -> <:AbstractDict{FusionTree{I,N-2}, <:Number} + elementary_trace(f::FusionTree{I,N}, i) where {I,N} -> <:AbstractDict{FusionTree{I,N-2}, <:Number} Perform an elementary trace of neighbouring uncoupled indices `i` and `i+1` on a fusion tree `f`, and returns the result as a dictionary of output trees and corresponding coefficients. """ -function elementary_trace(f::FusionTree{I,N}, i) where {I<:Sector,N} +function elementary_trace(f::FusionTree{I,N}, i) where {I,N} (N > 1 && 1 <= i <= N) || throw(ArgumentError("Cannot trace outputs i=$i and i+1 out of only $N outputs")) i < N || isone(f.coupled) || @@ -820,7 +806,7 @@ applying `artin_braid(f′, i; inv = true)` to all the outputs `f′` of tree with non-zero coefficient, namely `f` with coefficient `1`. This keyword has no effect if `BraidingStyle(sectortype(f)) isa SymmetricBraiding`. """ -function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I<:Sector,N} +function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I,N} 1 <= i < N || throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) uncoupled = f.uncoupled @@ -963,9 +949,7 @@ that if `i` and `j` cross, ``τ_{i,j}`` is applied if `levels[i] < levels[j]` an ``τ_{j,i}^{-1}`` if `levels[i] > levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid(f::FusionTree{I,N}, - levels::NTuple{N,Int}, - p::NTuple{N,Int}) where {I<:Sector,N} +function braid(f::FusionTree{I,N}, levels::NTuple{N,Int}, p::NTuple{N,Int}) where {I,N} TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding coeff = one(sectorscalartype(I)) @@ -1012,17 +996,16 @@ Perform a permutation of the uncoupled indices of the fusion tree `f` and return as a `<:AbstractDict` of output trees and corresponding coefficients; this requires that `BraidingStyle(sectortype(f)) isa SymmetricBraiding`. """ -function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I<:Sector,N} +function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I,N} @assert BraidingStyle(I) isa SymmetricBraiding return braid(f, ntuple(identity, Val(N)), p) end # braid double fusion tree """ - braid(f₁::FusionTree{I}, f₂::FusionTree{I}, - levels1::IndexTuple, levels2::IndexTuple, - p1::IndexTuple{N₁}, p2::IndexTuple{N₂}) where {I<:Sector, N₁, N₂} - -> <:AbstractDict{Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}}, <:Number} + braid((f₁, f₂)::FusionTreePair{I}, (levels1, levels2)::Index2Tuple, + (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} + -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a fusion-splitting tree pair that describes the fusion of a set of incoming uncoupled sectors to a set of outgoing uncoupled sectors, represented using the splitting @@ -1036,27 +1019,23 @@ respectively, which determines how indices braid. In particular, if `i` and `j` levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid(f₁::FusionTree{I}, f₂::FusionTree{I}, - levels1::IndexTuple, levels2::IndexTuple, - p1::IndexTuple{N₁}, p2::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} +function braid((f₁, f₂)::FusionTreePair{I}, (levels1, levels2)::Index2Tuple, + (p1, p2)::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} @assert length(f₁) + length(f₂) == N₁ + N₂ @assert length(f₁) == length(levels1) && length(f₂) == length(levels2) @assert TupleTools.isperm((p1..., p2...)) - return fsbraid((f₁, f₂, levels1, levels2, p1, p2)) + return fsbraid(((f₁, f₂), (levels1, levels2), (p1, p2))) end -const FSBraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I}, - IndexTuple,IndexTuple, - IndexTuple{N₁},IndexTuple{N₂}} +const FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreePair{I},Index2Tuple,Index2Tuple{N₁,N₂}} -@cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁, - N₂) where {I<:Sector,N₁,N₂} - (f₁, f₂, l1, l2, p1, p2) = key +@cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁, N₂) where {I,N₁,N₂} + ((f₁, f₂), (l1, l2), (p1, p2)) = key p = linearizepermutation(p1, p2, length(f₁), length(f₂)) levels = (l1..., reverse(l2)...) local newtrees - for ((f, f0), coeff1) in repartition(f₁, f₂, N₁ + N₂) + for ((f, f0), coeff1) in repartition((f₁, f₂), N₁ + N₂) for (f′, coeff2) in braid(f, levels, p) - for ((f₁′, f₂′), coeff3) in repartition(f′, f0, N₁) + for ((f₁′, f₂′), coeff3) in repartition((f′, f0), N₁) if @isdefined newtrees newtrees[(f₁′, f₂′)] = get(newtrees, (f₁′, f₂′), zero(coeff3)) + coeff1 * coeff2 * coeff3 @@ -1078,9 +1057,8 @@ function CacheStyle(::typeof(fsbraid), k::FSBraidKey{I}) where {I<:Sector} end """ - permute(f₁::FusionTree{I}, f₂::FusionTree{I}, - p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂} - -> <:AbstractDict{Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}}, <:Number} + permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} + -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a double fusion tree that describes the fusion of a set of incoming uncoupled sectors to a set of outgoing uncoupled sectors, represented using the individual trees of @@ -1089,10 +1067,9 @@ outgoing (`t1`) and incoming sectors (`t2`) respectively (with identical coupled repartitioning and permuting the tree such that sectors `p1` become outgoing and sectors `p2` become incoming. """ -function permute(f₁::FusionTree{I}, f₂::FusionTree{I}, - p1::IndexTuple{N₁}, p2::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} +function permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} @assert BraidingStyle(I) isa SymmetricBraiding levels1 = ntuple(identity, length(f₁)) levels2 = length(f₁) .+ ntuple(identity, length(f₂)) - return braid(f₁, f₂, levels1, levels2, p1, p2) + return braid((f₁, f₂), (levels1, levels2), (p1, p2)) end diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 15fce65d0..8d689834f 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -95,7 +95,7 @@ function planartrace!(C::AbstractTensorMap, end β′ = One() for (f₁, f₂) in fusiontrees(A) - for ((f₁′, f₂′), coeff) in planar_trace(f₁, f₂, p₁, p₂, q₁, q₂) + for ((f₁′, f₂′), coeff) in planar_trace((f₁, f₂), (p₁, p₂), (q₁, q₂)) TO.tensortrace!(C[f₁′, f₂′], A[f₁, f₂], (p₁, p₂), (q₁, q₂), false, α * coeff, β′, backend, allocator) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index a4a184e5f..3889f7551 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -230,7 +230,7 @@ function planarcontract!(C::AbstractTensorMap, inv_braid = τ_levels[cindA[1]] > τ_levels[cindA[2]] for (f₁, f₂) in fusiontrees(B) local newtrees - for ((f₁′, f₂′), coeff′) in transpose(f₁, f₂, cindB, oindB) + for ((f₁′, f₂′), coeff′) in transpose((f₁, f₂), (cindB, oindB)) for (f₁′′, coeff′′) in artin_braid(f₁′, 1; inv=inv_braid) f12 = (f₁′′, f₂′) coeff = coeff′ * coeff′′ @@ -281,7 +281,7 @@ function planarcontract!(C::AbstractTensorMap, for (f₁, f₂) in fusiontrees(A) local newtrees - for ((f₁′, f₂′), coeff′) in transpose(f₁, f₂, oindA, cindA) + for ((f₁′, f₂′), coeff′) in transpose((f₁, f₂), (oindA, cindA)) for (f₂′′, coeff′′) in artin_braid(f₂′, 1; inv=inv_braid) f12 = (f₁′, f₂′′) coeff = coeff′ * conj(coeff′′) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 88d0c3b25..7b0ea9cf2 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -203,7 +203,7 @@ function permute(d::DiagonalTensorMap, (p₁, p₂)::Index2Tuple{1,1}; d′ = typeof(d)(undef, dual(d.domain)) for (c, b) in blocks(d) f = only(fusiontrees(codomain(d), c)) - ((f′, _), coeff) = only(permute(f, f, p₁, p₂)) + ((f′, _), coeff) = only(permute((f, f), (p₁, p₂))) c′ = f′.coupled scale!(block(d′, c′), b, coeff) end diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index c4e2b9228..6cda0c584 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -17,7 +17,7 @@ function flip(t::AbstractTensorMap, I; inv::Bool=false) P = flip(space(t), I) t′ = similar(t, P) for (f₁, f₂) in fusiontrees(t) - (f₁′, f₂′), factor = only(flip(f₁, f₂, I; inv)) + (f₁′, f₂′), factor = only(flip((f₁, f₂), I; inv)) scale!(t′[f₁′, f₂′], t[f₁, f₂], factor) end return t′ @@ -558,7 +558,7 @@ function _add_abelian_kernel_threaded!(tdst, tsrc, p, transformer, α, β, backe end function _add_abelian_block!(tdst, tsrc, p, transformer, f₁, f₂, α, β, backend...) - (f₁′, f₂′), coeff = first(transformer(f₁, f₂)) + (f₁′, f₂′), coeff = first(transformer((f₁, f₂))) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β, backend...) return nothing @@ -618,7 +618,7 @@ function _add_general_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β, ba tdst = scale!(tdst, β) end for (f₁, f₂) in fusiontrees(tsrc) - for ((f₁′, f₂′), coeff) in transformer(f₁, f₂) + for ((f₁′, f₂′), coeff) in transformer((f₁, f₂)) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, One(), backend...) end @@ -683,7 +683,7 @@ end function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, backend...) for (f₁, f₂) in fusiontrees(tsrc) (f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue - for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂) + for ((f₁′, f₂′), coeff) in fusiontreetransform((f₁, f₂)) @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, One(), backend...) end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 6110093b6..5910c3322 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -199,7 +199,7 @@ function trace_permute!(tdst::AbstractTensorMap, r₁ = (p₁..., q₁...) r₂ = (p₂..., q₂...) for (f₁, f₂) in fusiontrees(tsrc) - for ((f₁′, f₂′), coeff) in permute(f₁, f₂, r₁, r₂) + for ((f₁′, f₂′), coeff) in permute((f₁, f₂), (r₁, r₂)) f₁′′, g₁ = split(f₁′, N₁) f₂′′, g₂ = split(f₂′, N₂) g₁ == g₂ || continue diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 046af5edc..aa9a3da12 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -26,7 +26,7 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) for i in 1:L f₁, f₂ = structure_src.fusiontreelist[i] - (f₃, f₄), coeff = only(transform(f₁, f₂)) + (f₃, f₄), coeff = only(transform((f₁, f₂))) j = structure_dst.fusiontreeindices[(f₃, f₄)] stridestructure_dst = structure_dst.fusiontreestructure[j] stridestructure_src = structure_src.fusiontreestructure[i] @@ -166,14 +166,14 @@ end # braid is special because it has levels function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple, levels) - return fusiontreetransform(f1, f2) = braid(f1, f2, levels..., p...) + return fusiontreetransform((f1, f2)) = braid((f1, f2), levels, p) end function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple, levels) return treebraider(space(tdst), space(tsrc), p, levels) end @cached function treebraider(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels)::treetransformertype(Vdst, Vsrc) - fusiontreebraider(f1, f2) = braid(f1, f2, levels..., p...) + fusiontreebraider((f1, f2)) = braid((f1, f2), levels, p) return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc) end @@ -181,14 +181,14 @@ for (transform, treetransformer) in ((:permute, :treepermuter), (:transpose, :treetransposer)) @eval begin function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple) - return fusiontreetransform(f1, f2) = $transform(f1, f2, p...) + return fusiontreetransform(f1, f2) = $transform((f1, f2), p) end function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple) return $treetransformer(space(tdst), space(tsrc), p) end @cached function $treetransformer(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple)::treetransformertype(Vdst, Vsrc) - fusiontreetransform(f1, f2) = $transform(f1, f2, p...) + fusiontreetransform((f1, f2)) = $transform((f1, f2), p) return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc) end end diff --git a/test/fusiontrees.jl b/test/fusiontrees.jl index d758f9a83..707002966 100644 --- a/test/fusiontrees.jl +++ b/test/fusiontrees.jl @@ -142,7 +142,7 @@ ti = time() @test bf ≈ bf′ atol = 1e-12 end - d2 = @constinferred TK.planar_trace(f, (1, 3), (2, 4)) + d2 = @constinferred TK.planar_trace(f, ((1, 3), (2, 4))) oind2 = (5, 6, 7) bf2 = tensortrace(af, (:a, :a, :b, :b, :c, :d, :e)) bf2′ = zero(bf2) @@ -151,7 +151,7 @@ ti = time() end @test bf2 ≈ bf2′ atol = 1e-12 - d2 = @constinferred TK.planar_trace(f, (5, 6), (2, 1)) + d2 = @constinferred TK.planar_trace(f, ((5, 6), (2, 1))) oind2 = (3, 4, 7) bf2 = tensortrace(af, (:a, :b, :c, :d, :b, :a, :e)) bf2′ = zero(bf2) @@ -160,7 +160,7 @@ ti = time() end @test bf2 ≈ bf2′ atol = 1e-12 - d2 = @constinferred TK.planar_trace(f, (1, 4), (6, 3)) + d2 = @constinferred TK.planar_trace(f, ((1, 4), (6, 3))) bf2 = tensortrace(af, (:a, :b, :c, :c, :d, :a, :e)) bf2′ = zero(bf2) for (f2′, coeff) in d2 @@ -170,7 +170,7 @@ ti = time() q1 = (1, 3, 5) q2 = (2, 4, 6) - d3 = @constinferred TK.planar_trace(f, q1, q2) + d3 = @constinferred TK.planar_trace(f, (q1, q2)) bf3 = tensortrace(af, (:a, :a, :b, :b, :c, :c, :d)) bf3′ = zero(bf3) for (f3′, coeff) in d3 @@ -180,7 +180,7 @@ ti = time() q1 = (1, 3, 5) q2 = (6, 2, 4) - d3 = @constinferred TK.planar_trace(f, q1, q2) + d3 = @constinferred TK.planar_trace(f, (q1, q2)) bf3 = tensortrace(af, (:a, :b, :b, :c, :c, :a, :d)) bf3′ = zero(bf3) for (f3′, coeff) in d3 @@ -190,7 +190,7 @@ ti = time() q1 = (1, 2, 3) q2 = (6, 5, 4) - d3 = @constinferred TK.planar_trace(f, q1, q2) + d3 = @constinferred TK.planar_trace(f, (q1, q2)) bf3 = tensortrace(af, (:a, :b, :c, :c, :b, :a, :d)) bf3′ = zero(bf3) for (f3′, coeff) in d3 @@ -200,7 +200,7 @@ ti = time() q1 = (1, 2, 4) q2 = (6, 3, 5) - d3 = @constinferred TK.planar_trace(f, q1, q2) + d3 = @constinferred TK.planar_trace(f, (q1, q2)) bf3 = tensortrace(af, (:a, :b, :b, :c, :c, :a, :d)) bf3′ = zero(bf3) for (f3′, coeff) in d3 @@ -381,12 +381,12 @@ ti = time() @testset "Double fusion tree $Istr: repartioning" begin for n in 0:(2 * N) - d = @constinferred TK.repartition(f1, f2, $n) + d = @constinferred TK.repartition((f1, f2), $n) @test dim(incoming) ≈ sum(abs2(coef) * dim(f1.coupled) for ((f1, f2), coef) in d) d2 = Dict{typeof((f1, f2)),valtype(d)}() for ((f1′, f2′), coeff) in d - for ((f1′′, f2′′), coeff2) in TK.repartition(f1′, f2′, N) + for ((f1′′, f2′′), coeff2) in TK.repartition((f1′, f2′), N) d2[(f1′′, f2′′)] = get(d2, (f1′′, f2′′), zero(coeff)) + coeff2 * coeff end end @@ -432,12 +432,12 @@ ti = time() ip = invperm(p) ip1, ip2 = ip[1:N], ip[(N + 1):(2N)] - d = @constinferred TensorKit.permute(f1, f2, p1, p2) + d = @constinferred TensorKit.permute((f1, f2), (p1, p2)) @test dim(incoming) ≈ sum(abs2(coef) * dim(f1.coupled) for ((f1, f2), coef) in d) d2 = Dict{typeof((f1, f2)),valtype(d)}() for ((f1′, f2′), coeff) in d - d′ = TensorKit.permute(f1′, f2′, ip1, ip2) + d′ = TensorKit.permute((f1′, f2′), (ip1, ip2)) for ((f1′′, f2′′), coeff2) in d′ d2[(f1′′, f2′′)] = get(d2, (f1′′, f2′′), zero(coeff)) + coeff2 * coeff @@ -490,12 +490,12 @@ ti = time() ip′ = tuple(getindex.(Ref(vcat(1:n, (2N):-1:(n + 1))), ip)...) ip1, ip2 = ip′[1:N], ip′[(2N):-1:(N + 1)] - d = @constinferred transpose(f1, f2, p1, p2) + d = @constinferred transpose((f1, f2), (p1, p2)) @test dim(incoming) ≈ sum(abs2(coef) * dim(f1.coupled) for ((f1, f2), coef) in d) d2 = Dict{typeof((f1, f2)),valtype(d)}() for ((f1′, f2′), coeff) in d - d′ = transpose(f1′, f2′, ip1, ip2) + d′ = transpose((f1′, f2′), (ip1, ip2)) for ((f1′′, f2′′), coeff2) in d′ d2[(f1′′, f2′′)] = get(d2, (f1′′, f2′′), zero(coeff)) + coeff2 * coeff end @@ -509,7 +509,7 @@ ti = time() end if BraidingStyle(I) isa Bosonic - d3 = permute(f1, f2, p1, p2) + d3 = permute((f1, f2), (p1, p2)) for (f1′, f2′) in union(keys(d), keys(d3)) coeff1 = get(d, (f1′, f2′), zero(valtype(d))) coeff3 = get(d3, (f1′, f2′), zero(valtype(d3))) @@ -546,14 +546,15 @@ ti = time() end end @testset "Double fusion tree $Istr: planar trace" begin - d1 = transpose(f1, f1, (N + 1, 1:N..., ((2N):-1:(N + 3))...), (N + 2,)) + d1 = transpose((f1, f1), ((N + 1, 1:N..., ((2N):-1:(N + 3))...), (N + 2,))) f1front, = TK.split(f1, N - 1) T = typeof(Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1, 1, 1]) d2 = Dict{typeof((f1front, f1front)),T}() for ((f1′, f2′), coeff′) in d1 for ((f1′′, f2′′), coeff′′) in - TK.planar_trace(f1′, f2′, (2:N...,), (1, ((2N):-1:(N + 3))...), (N + 1,), - (N + 2,)) + TK.planar_trace((f1′, f2′), ((2:N...,), (1, ((2N):-1:(N + 3))...)), + ((N + 1,), + (N + 2,))) coeff = coeff′ * coeff′′ d2[(f1′′, f2′′)] = get(d2, (f1′′, f2′′), zero(coeff)) + coeff end From 03c3c3b88eb794d6a89661dfcfda4e1160c7f44c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 20:49:21 -0400 Subject: [PATCH 02/28] implement "vectorized" fusiontree manipulations --- src/fusiontrees/fusiontrees.jl | 7 +- src/fusiontrees/uncouplediterator.jl | 353 +++++++++++++++++++++++++++ 2 files changed, 357 insertions(+), 3 deletions(-) create mode 100644 src/fusiontrees/uncouplediterator.jl diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 71d67953d..b44f73b34 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -225,11 +225,12 @@ function Base.show(io::IO, t::FusionTree{I}) where {I<:Sector} end end -# Manipulate fusion trees -include("manipulations.jl") - # Fusion tree iterators include("iterator.jl") +include("uncouplediterator.jl") + +# Manipulate fusion trees +include("manipulations.jl") # auxiliary routines # _abelianinner: generate the inner indices for given outer indices in the abelian case diff --git a/src/fusiontrees/uncouplediterator.jl b/src/fusiontrees/uncouplediterator.jl new file mode 100644 index 000000000..e0e923167 --- /dev/null +++ b/src/fusiontrees/uncouplediterator.jl @@ -0,0 +1,353 @@ +struct OuterTreeIterator{I<:Sector,N₁,N₂} + uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}} + isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}} +end + +sectortype(::Type{<:OuterTreeIterator{I}}) where {I} = I +numout(fs::OuterTreeIterator) = numout(typeof(fs)) +numout(::Type{<:OuterTreeIterator{I,N₁}}) where {I,N₁} = N₁ +numin(fs::OuterTreeIterator) = numin(typeof(fs)) +numin(::Type{<:OuterTreeIterator{I,N₁,N₂}}) where {I,N₁,N₂} = N₂ +numind(fs::OuterTreeIterator) = numind(typeof(fs)) +numind(::Type{T}) where {T<:OuterTreeIterator} = numin(T) + numout(T) + +# TODO: should we make this an actual iterator? +function fusiontrees(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + + trees = Vector{Tuple{F₁,F₂}}(undef, 0) + for c in blocksectors(iter), f₁ in fusiontrees(iter.uncoupled[1], c, iter.isdual[1]), + f₂ in fusiontrees(iter.uncoupled[2], c, iter.isdual[2]) + + push!(trees, (f₁, f₂)) + end + return trees +end + +# TODO: better implementation +Base.length(iter::OuterTreeIterator) = length(fusiontrees(iter)) + +function blocksectors(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + I == Trivial && return (Trivial(),) + + bs_codomain = Vector{I}() + if N₁ == 0 + push!(bs_codomain, one(I)) + elseif N₁ == 1 + push!(bs_codomain, only(iter.uncoupled[1])) + else + for c in ⊗(iter.uncoupled[1]...) + if !(c in bs_codomain) + push!(bs_codomain, c) + end + end + end + + bs_domain = Vector{I}() + if N₂ == 0 + push!(bs_domain, one(I)) + elseif N₂ == 1 + push!(bs_domain, only(iter.uncoupled[2])) + else + for c in ⊗(iter.uncoupled[2]...) + if !(c in bs_domain) + push!(bs_domain, c) + end + end + end + + return sort!(collect(intersect(bs_codomain, bs_domain))) +end + +# Manipulations +# ------------- + +function bendright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + uncoupled_dst = (TupleTools.front(fs_src.uncoupled[1]), + (fs_src.uncoupled[2]..., dual(fs_src.uncoupled[1][end]))) + isdual_dst = (TupleTools.front(fs_src.isdual[1]), + (fs_src.isdual[2]..., !(fs_src.isdual[1][end]))) + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, f) in enumerate(trees_src) + for (f′, c) in bendright(f) + row = indexmap[f′] + U[row, col] = c + end + end + + return fs_dst, U +end + +# TODO: verify if this can be computed through an adjoint +function bendleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + uncoupled_dst = ((fs_src.uncoupled[1]..., dual(fs_src.uncoupled[2][end])), + TupleTools.front(fs_src.uncoupled[2])) + isdual_dst = ((fs_src.isdual[1]..., !(fs_src.isdual[2][end])), + TupleTools.front(fs_src.isdual[2])) + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, f) in enumerate(trees_src) + for (f′, c) in bendleft(f) + row = indexmap[f′] + U[row, col] = c + end + end + + return fs_dst, U +end + +function foldright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + uncoupled_dst = (Base.tail(fs_src.uncoupled[1]), + (dual(first(fs_src.uncoupled[1])), fs_src.uncoupled[2]...)) + isdual_dst = (Base.tail(fs_src.isdual[1]), + (!first(fs_src.isdual[1]), fs_src.isdual[2]...)) + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, f) in enumerate(trees_src) + for (f′, c) in foldright(f) + row = indexmap[f′] + U[row, col] = c + end + end + + return fs_dst, U +end + +# TODO: verify if this can be computed through an adjoint +function foldleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + uncoupled_dst = ((dual(first(fs_src.uncoupled[2])), fs_src.uncoupled[1]...), + Base.tail(fs_src.uncoupled[2])) + isdual_dst = ((!first(fs_src.isdual[2]), fs_src.isdual[1]...), + Base.tail(fs_src.isdual[2])) + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, f) in enumerate(trees_src) + for (f′, c) in foldleft(f) + row = indexmap[f′] + U[row, col] = c + end + end + + return fs_dst, U +end + +function cycleclockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + if N₁ > 0 + fs_tmp, U₁ = foldright(fs_src) + fs_dst, U₂ = bendleft(fs_tmp) + else + fs_tmp, U₁ = bendleft(fs_src) + fs_dst, U₂ = foldright(fs_tmp) + end + return fs_dst, U₂ * U₁ +end + +function cycleanticlockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} + if N₂ > 0 + fs_tmp, U₁ = foldleft(fs_src) + fs_dst, U₂ = bendright(fs_tmp) + else + fs_tmp, U₁ = bendright(fs_src) + fs_dst, U₂ = foldleft(fs_tmp) + end + return fs_dst, U₂ * U₁ +end + +@inline function repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, N::Int) where {I,N₁,N₂} + @assert 0 <= N <= N₁ + N₂ + return _recursive_repartition(fs_src, Val(N)) +end + +function _repartition_type(I, N, N₁, N₂) + return Tuple{OuterTreeIterator{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}} +end +function _recursive_repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, + ::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N} + if N == N₁ + fs_dst = fs_src + U = zeros(sectorscalartype(I), length(fs_dst), length(fs_src)) + copyto!(U, LinearAlgebra.I) + return fs_dst, U + end + + N == N₁ - 1 && return bendright(fs_src) + N == N₁ + 1 && return bendleft(fs_src) + + fs_tmp, U₁ = N < N₁ ? bendright(fs_src) : bendleft(fs_src) + fs_dst, U₂ = _recursive_repartition(fs_tmp, Val(N)) + return fs_dst, U₂ * U₁ +end + +function Base.transpose(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} + N = N₁ + N₂ + @assert numind(fs_src) == N + p′ = linearizepermutation(p..., numout(fs_src), numin(fs_src)) + @assert iscyclicpermutation(p′) + return _fstranspose((fs_src, p)) +end + +const _FSTransposeKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂}} + +@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁, + N₂}, + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} + fs_src, (p1, p2) = key + + N = N₁ + N₂ + p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) + + fs_dst, U = repartition(fs_src, N₁) + length(p) == 0 && return fs_dst, U + i1 = findfirst(==(1), p)::Int + i1 == 1 && return fs_dst, U + + Nhalf = N >> 1 + while 1 < i1 ≤ Nhalf + fs_dst, U_tmp = cycleanticlockwise(fs_dst) + U = U_tmp * U + i1 -= 1 + end + while Nhalf < i1 + fs_dst, U_tmp = cycleclockwise(fs_dst) + U = U_tmp * U + i1 = mod1(i1 + 1, N) + end + + return fs_dst, U +end + +function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I} + if FusionStyle(I) == UniqueFusion() + return NoCache() + else + return GlobalLRUCache() + end +end + +function artin_braid(fs_src::OuterTreeIterator{I,N,0}, i; inv::Bool=false) where {I,N} + 1 <= i < N || + throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) + + uncoupled = fs_src.uncoupled[1] + uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i) + uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1) + + isdual = fs_src.isdual[1] + isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) + isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) + + fs_dst = OuterTreeIterator((uncoupled′, ()), (isdual′, ())) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, (f₁, f₂)) in enumerate(trees_src) + for (f₁′, c) in artin_braid(f₁, i; inv) + row = indexmap[(f₁′, f₂)] + U[row, col] = c + end + end + + return fs_dst, U +end + +function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int}, + p::NTuple{N,Int}) where {I,N} + TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) + + if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding + uncoupled′ = TupleTools._permute(fs_src.uncoupled[1], p) + isdual′ = TupleTools._permute(fs_src.isdual[1], p) + fs_dst = OuterTreeIterator(uncoupled′, isdual′) + + trees_src = fusiontrees(fs_src) + trees_dst = fusiontrees(fs_dst) + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) + + for (col, (f₁, f₂)) in enumerate(trees_src) + for (f₁′, c) in braid(f₁, levels, p) + row = indexmap[(f₁′, f₂)] + U[row, col] = c + end + end + + return fs_dst, U + end + + fs_dst, U = repartition(fs_src, N) # TODO: can we avoid this? + for s in permutation2swaps(p) + inv = levels[s] > levels[s + 1] + fs_dst, U_tmp = artin_braid(fs_dst, s; inv) + U = U_tmp * U + end + return fs_dst, U +end + +function braid(fs_src::OuterTreeIterator{I}, levels::Index2Tuple, + p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} + @assert numind(fs_src) == N₁ + N₂ + @assert numout(fs_src) == length(levels[1]) && numin(fs_src) == length(levels[2]) + @assert TupleTools.isperm((p[1]..., p[2]...)) + return _fsbraid((fs_src, levels, p)) +end + +const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple,Index2Tuple{N₁,N₂}} + +@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,N₂}, + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} + fs_src, (l1, l2), (p1, p2) = key + + p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) + levels = (l1..., reverse(l2)...) + + fs_dst, U = repartition(fs_src, numind(fs_src)) + fs_dst, U_tmp = braid(fs_dst, levels, p) + U = U_tmp * U + fs_dst, U_tmp = repartition(fs_dst, N₁) + U = U_tmp * U + return fs_dst, U +end + +function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I} + if FusionStyle(I) isa UniqueFusion + return NoCache() + else + return GlobalLRUCache() + end +end + +function permute(fs_src::OuterTreeIterator{I}, p::Index2Tuple) where {I} + @assert BraidingStyle(I) isa SymmetricBraiding + levels1 = ntuple(identity, numout(fs_src)) + levels2 = numout(fs_src) .+ ntuple(identity, numin(fs_src)) + return braid(fs_src, (levels1, levels2), p) +end From 9c34aab92a5649ea63da8ed3df58d6fab7475052 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 21:22:19 -0400 Subject: [PATCH 03/28] Refactor treetransformer to make use of vectorized implementation --- src/tensors/treetransformers.jl | 49 ++++++++++++--------------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index aa9a3da12..af2c58ad9 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -62,39 +62,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) fusionstructure_dst = structure_dst.fusiontreestructure structure_src = fusionblockstructure(Vsrc) fusionstructure_src = structure_src.fusiontreestructure - I = sectortype(Vsrc) - - uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂) - return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) - end - uncoupleds_src_unique = unique(uncoupleds_src) - - uncoupleds_dst = map(structure_dst.fusiontreelist) do (f₁, f₂) - return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) - end + I = sectortype(Vsrc) T = sectorscalartype(I) N = numind(Vdst) - L = length(uncoupleds_src_unique) - data = Vector{_GenericTransformerData{T,N}}(undef, L) + data = Vector{_GenericTransformerData{T,N}}() - # TODO: this can be multithreaded - for (i, uncoupled) in enumerate(uncoupleds_src_unique) - inds_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = structure_src.fusiontreelist[inds_src] + isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces)) + for cod_uncoupled_src in sectors(codomain(Vsrc)), + dom_uncoupled_src in sectors(domain(Vsrc)) - uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) - inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) + fs_src = OuterTreeIterator((cod_uncoupled_src, dom_uncoupled_src), isdual_src) + trees_src = fusiontrees(fs_src) + isempty(trees_src) && continue - fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst] + fs_dst, U = transform(fs_src) + matrix = copy(transpose(U)) # TODO: should we avoid this - matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) - for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) - for ((f₃, f₄), coeff) in transform(f₁, f₂) - col = findfirst(==((f₃, f₄)), fusiontrees_outer_dst)::Int - matrix[row, col] = coeff - end - end + inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src) + trees_dst = fusiontrees(fs_dst) + inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst) # size is shared between blocks, so repack: # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...]) @@ -104,7 +91,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) @debug("Created recoupling block for uncoupled: $uncoupled", sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix)) - data[i] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)) + push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))) end transformer = GenericTreeTransformer{T,N}(data) @@ -166,14 +153,14 @@ end # braid is special because it has levels function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple, levels) - return fusiontreetransform((f1, f2)) = braid((f1, f2), levels, p) + return fusiontreetransform(f) = braid(f, levels, p) end function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple, levels) return treebraider(space(tdst), space(tsrc), p, levels) end @cached function treebraider(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels)::treetransformertype(Vdst, Vsrc) - fusiontreebraider((f1, f2)) = braid((f1, f2), levels, p) + fusiontreebraider(f) = braid(f, levels, p) return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc) end @@ -181,14 +168,14 @@ for (transform, treetransformer) in ((:permute, :treepermuter), (:transpose, :treetransposer)) @eval begin function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple) - return fusiontreetransform(f1, f2) = $transform((f1, f2), p) + return fusiontreetransform(f) = $transform(f, p) end function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple) return $treetransformer(space(tdst), space(tsrc), p) end @cached function $treetransformer(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple)::treetransformertype(Vdst, Vsrc) - fusiontreetransform((f1, f2)) = $transform((f1, f2), p) + fusiontreetransform(f) = $transform(f, p) return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc) end end From 1f1f734628f2cad730e6bb23b01bf9ce6a54f0eb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 22:17:00 -0400 Subject: [PATCH 04/28] fix arg order `braid` --- src/fusiontrees/manipulations.jl | 26 ++++++++++----------- src/fusiontrees/uncouplediterator.jl | 35 ++++++++++++++-------------- src/tensors/treetransformers.jl | 4 ++-- test/fusiontrees.jl | 10 ++++---- 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index e9122e592..9a8402870 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -936,7 +936,7 @@ end # braid fusion tree """ - braid(f::FusionTree{<:Sector, N}, levels::NTuple{N, Int}, p::NTuple{N, Int}) + braid(f::FusionTree{<:Sector, N}, p::NTuple{N, Int}, levels::NTuple{N, Int}) -> <:AbstractDict{typeof(t), <:Number} Perform a braiding of the uncoupled indices of the fusion tree `f` and return the result as @@ -949,7 +949,7 @@ that if `i` and `j` cross, ``τ_{i,j}`` is applied if `levels[i] < levels[j]` an ``τ_{j,i}^{-1}`` if `levels[i] > levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid(f::FusionTree{I,N}, levels::NTuple{N,Int}, p::NTuple{N,Int}) where {I,N} +function braid(f::FusionTree{I,N}, p::NTuple{N,Int}, levels::NTuple{N,Int}) where {I,N} TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding coeff = one(sectorscalartype(I)) @@ -998,13 +998,13 @@ as a `<:AbstractDict` of output trees and corresponding coefficients; this requi """ function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I,N} @assert BraidingStyle(I) isa SymmetricBraiding - return braid(f, ntuple(identity, Val(N)), p) + return braid(f, p, ntuple(identity, Val(N))) end # braid double fusion tree """ - braid((f₁, f₂)::FusionTreePair{I}, (levels1, levels2)::Index2Tuple, - (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} + braid((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}, + (levels1, levels2)::Index2Tuple) where {I,N₁,N₂} -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a fusion-splitting tree pair that describes the fusion of a set of incoming @@ -1019,22 +1019,22 @@ respectively, which determines how indices braid. In particular, if `i` and `j` levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid((f₁, f₂)::FusionTreePair{I}, (levels1, levels2)::Index2Tuple, - (p1, p2)::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} +function braid((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}, + (levels1, levels2)::Index2Tuple) where {I,N₁,N₂} @assert length(f₁) + length(f₂) == N₁ + N₂ @assert length(f₁) == length(levels1) && length(f₂) == length(levels2) @assert TupleTools.isperm((p1..., p2...)) - return fsbraid(((f₁, f₂), (levels1, levels2), (p1, p2))) + return fsbraid(((f₁, f₂), (p1, p2), (levels1, levels2))) end -const FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreePair{I},Index2Tuple,Index2Tuple{N₁,N₂}} +const FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreePair{I},Index2Tuple{N₁,N₂},Index2Tuple} @cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁, N₂) where {I,N₁,N₂} - ((f₁, f₂), (l1, l2), (p1, p2)) = key + ((f₁, f₂), (p1, p2), (l1, l2)) = key p = linearizepermutation(p1, p2, length(f₁), length(f₂)) levels = (l1..., reverse(l2)...) local newtrees for ((f, f0), coeff1) in repartition((f₁, f₂), N₁ + N₂) - for (f′, coeff2) in braid(f, levels, p) + for (f′, coeff2) in braid(f, p, levels) for ((f₁′, f₂′), coeff3) in repartition((f′, f0), N₁) if @isdefined newtrees newtrees[(f₁′, f₂′)] = get(newtrees, (f₁′, f₂′), zero(coeff3)) + @@ -1067,9 +1067,9 @@ outgoing (`t1`) and incoming sectors (`t2`) respectively (with identical coupled repartitioning and permuting the tree such that sectors `p1` become outgoing and sectors `p2` become incoming. """ -function permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} +function permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} @assert BraidingStyle(I) isa SymmetricBraiding levels1 = ntuple(identity, length(f₁)) levels2 = length(f₁) .+ ntuple(identity, length(f₂)) - return braid((f₁, f₂), (levels1, levels2), (p1, p2)) + return braid((f₁, f₂), (p1, p2), (levels1, levels2)) end diff --git a/src/fusiontrees/uncouplediterator.jl b/src/fusiontrees/uncouplediterator.jl index e0e923167..d5906b417 100644 --- a/src/fusiontrees/uncouplediterator.jl +++ b/src/fusiontrees/uncouplediterator.jl @@ -19,7 +19,6 @@ function fusiontrees(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} trees = Vector{Tuple{F₁,F₂}}(undef, 0) for c in blocksectors(iter), f₁ in fusiontrees(iter.uncoupled[1], c, iter.isdual[1]), f₂ in fusiontrees(iter.uncoupled[2], c, iter.isdual[2]) - push!(trees, (f₁, f₂)) end return trees @@ -211,10 +210,10 @@ end const _FSTransposeKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂}} @cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁, - N₂}, - Matrix{sectorscalartype(I)}} where {I, - N₁, - N₂} + N₂}, + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} fs_src, (p1, p2) = key N = N₁ + N₂ @@ -277,8 +276,8 @@ function artin_braid(fs_src::OuterTreeIterator{I,N,0}, i; inv::Bool=false) where return fs_dst, U end -function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int}, - p::NTuple{N,Int}) where {I,N} +function braid(fs_src::OuterTreeIterator{I,N,0}, p::NTuple{N,Int}, + levels::NTuple{N,Int}) where {I,N} TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding @@ -292,7 +291,7 @@ function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int}, U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) for (col, (f₁, f₂)) in enumerate(trees_src) - for (f₁′, c) in braid(f₁, levels, p) + for (f₁′, c) in braid(f₁, p, levels) row = indexmap[(f₁′, f₂)] U[row, col] = c end @@ -310,27 +309,27 @@ function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int}, return fs_dst, U end -function braid(fs_src::OuterTreeIterator{I}, levels::Index2Tuple, - p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} +function braid(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}, + levels::Index2Tuple) where {I,N₁,N₂} @assert numind(fs_src) == N₁ + N₂ @assert numout(fs_src) == length(levels[1]) && numin(fs_src) == length(levels[2]) @assert TupleTools.isperm((p[1]..., p[2]...)) - return _fsbraid((fs_src, levels, p)) + return _fsbraid((fs_src, p, levels)) end -const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple,Index2Tuple{N₁,N₂}} +const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂},Index2Tuple} @cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,N₂}, - Matrix{sectorscalartype(I)}} where {I, - N₁, - N₂} - fs_src, (l1, l2), (p1, p2) = key + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} + fs_src, (p1, p2), (l1, l2) = key p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) levels = (l1..., reverse(l2)...) fs_dst, U = repartition(fs_src, numind(fs_src)) - fs_dst, U_tmp = braid(fs_dst, levels, p) + fs_dst, U_tmp = braid(fs_dst, p, levels) U = U_tmp * U fs_dst, U_tmp = repartition(fs_dst, N₁) U = U_tmp * U @@ -349,5 +348,5 @@ function permute(fs_src::OuterTreeIterator{I}, p::Index2Tuple) where {I} @assert BraidingStyle(I) isa SymmetricBraiding levels1 = ntuple(identity, numout(fs_src)) levels2 = numout(fs_src) .+ ntuple(identity, numin(fs_src)) - return braid(fs_src, (levels1, levels2), p) + return braid(fs_src, p, (levels1, levels2)) end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index af2c58ad9..8cec1fefd 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -153,14 +153,14 @@ end # braid is special because it has levels function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple, levels) - return fusiontreetransform(f) = braid(f, levels, p) + return fusiontreetransform(f) = braid(f, p, levels) end function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple, levels) return treebraider(space(tdst), space(tsrc), p, levels) end @cached function treebraider(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels)::treetransformertype(Vdst, Vsrc) - fusiontreebraider(f) = braid(f, levels, p) + fusiontreebraider(f) = braid(f, p, levels) return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc) end diff --git a/test/fusiontrees.jl b/test/fusiontrees.jl index 707002966..f94e52c6e 100644 --- a/test/fusiontrees.jl +++ b/test/fusiontrees.jl @@ -89,13 +89,13 @@ ti = time() @test c′ == one(c′) return t′ end - braid_i_to_1 = braid(f1, levels, (i, (1:(i - 1))..., ((i + 1):N)...)) + braid_i_to_1 = braid(f1, (i, (1:(i - 1))..., ((i + 1):N)...), levels) trees2 = Dict(_reinsert_partial_tree(t, f2) => c for (t, c) in braid_i_to_1) trees3 = empty(trees2) p = (((N + 1):(N + i - 1))..., (1:N)..., ((N + i):(2N - 1))...) levels = ((i:(N + i - 1))..., (1:(i - 1))..., ((i + N):(2N - 1))...) for (t, coeff) in trees2 - for (t′, coeff′) in braid(t, levels, p) + for (t′, coeff′) in braid(t, p, levels) trees3[t′] = get(trees3, t′, zero(coeff′)) + coeff * coeff′ end end @@ -273,11 +273,11 @@ ti = time() ip = invperm(p) levels = ntuple(identity, N) - d = @constinferred braid(f, levels, p) + d = @constinferred braid(f, p, levels) d2 = Dict{typeof(f),valtype(d)}() levels2 = p for (f2, coeff) in d - for (f1, coeff2) in braid(f2, levels2, ip) + for (f1, coeff2) in braid(f2, ip, levels2) d2[f1] = get(d2, f1, zero(coeff)) + coeff2 * coeff end end @@ -334,7 +334,7 @@ ti = time() perm = ((N .+ (1:N))..., (1:N)...) levels = ntuple(identity, 2 * N) for (t, coeff) in trees1 - for (t′, coeff′) in braid(t, levels, perm) + for (t′, coeff′) in braid(t, perm, levels) trees3[t′] = get(trees3, t′, zero(valtype(trees3))) + coeff * coeff′ end end From 7de3afb84558d9443765d7a2e3da47e29aa36b51 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 31 Jul 2025 12:04:58 -0400 Subject: [PATCH 05/28] refactor in terms of FusionTreeBlock --- src/fusiontrees/fusiontreeblocks.jl | 307 +++++++++++++++++++++++ src/fusiontrees/fusiontrees.jl | 2 +- src/fusiontrees/uncouplediterator.jl | 352 --------------------------- src/tensors/treetransformers.jl | 2 +- 4 files changed, 309 insertions(+), 354 deletions(-) create mode 100644 src/fusiontrees/fusiontreeblocks.jl delete mode 100644 src/fusiontrees/uncouplediterator.jl diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl new file mode 100644 index 000000000..f26834142 --- /dev/null +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -0,0 +1,307 @@ +struct FusionTreeBlock{I,N₁,N₂,F<:FusionTreePair{I,N₁,N₂}} + trees::Vector{F} +end + +function FusionTreeBlock(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}, + isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}) where {I<:Sector,N₁,N₂} + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + trees = Vector{Tuple{F₁,F₂}}(undef, 0) + + cleft = N₁ == 0 ? (one(I),) : ⊗(uncoupled[1]...) + cright = N₂ == 0 ? (one(I),) : ⊗(uncoupled[2]...) + cs = sort!(collect(intersect(cleft, cright))) + for c in cs + for f₁ in fusiontrees(uncoupled[1], c, isdual[1]), + f₂ in fusiontrees(uncoupled[2], c, isdual[2]) + + push!(trees, (f₁, f₂)) + end + end + return FusionTreeBlock(trees) +end + +Base.@constprop :aggressive function Base.getproperty(block::FusionTreeBlock, prop::Symbol) + if prop === :uncoupled + f₁, f₂ = first(block.trees) + return f₁.uncoupled, f₂.uncoupled + elseif prop === :isdual + f₁, f₂ = first(block.trees) + return f₁.isdual, f₂.isdual + else + return getfield(block, prop) + end +end + +Base.propertynames(::FusionTreeBlock, private::Bool=false) = (:trees, :uncoupled, :isdual) + +sectortype(::Type{<:FusionTreeBlock{I}}) where {I} = I +numout(fs::FusionTreeBlock) = numout(typeof(fs)) +numout(::Type{<:FusionTreeBlock{I,N₁}}) where {I,N₁} = N₁ +numin(fs::FusionTreeBlock) = numin(typeof(fs)) +numin(::Type{<:FusionTreeBlock{I,N₁,N₂}}) where {I,N₁,N₂} = N₂ +numind(fs::FusionTreeBlock) = numind(typeof(fs)) +numind(::Type{T}) where {T<:FusionTreeBlock} = numin(T) + numout(T) + +fusiontrees(block::FusionTreeBlock) = block.trees +Base.length(block::FusionTreeBlock) = length(fusiontrees(block)) + +# Manipulations +# ------------- +function transformation_matrix(transform, dst::FusionTreeBlock{I}, + src::FusionTreeBlock{I}) where {I} + U = zeros(sectorscalartype(I), length(dst), length(src)) + indexmap = Dict(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + for (col, f) in enumerate(fusiontrees(src)) + for (f′, c) in transform(f) + row = indexmap[f′] + U[row, col] = c + end + end + return U +end + +function bendright(src::FusionTreeBlock) + uncoupled_dst = (TupleTools.front(src.uncoupled[1]), + (src.uncoupled[2]..., dual(src.uncoupled[1][end]))) + isdual_dst = (TupleTools.front(src.isdual[1]), + (src.isdual[2]..., !(src.isdual[1][end]))) + dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + + U = transformation_matrix(bendright, dst, src) + return dst, U +end + +# TODO: verify if this can be computed through an adjoint +function bendleft(src::FusionTreeBlock) + uncoupled_dst = ((src.uncoupled[1]..., dual(src.uncoupled[2][end])), + TupleTools.front(src.uncoupled[2])) + isdual_dst = ((src.isdual[1]..., !(src.isdual[2][end])), + TupleTools.front(src.isdual[2])) + dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + + U = transformation_matrix(bendleft, dst, src) + return dst, U +end + +function foldright(src::FusionTreeBlock) + uncoupled_dst = (Base.tail(src.uncoupled[1]), + (dual(first(src.uncoupled[1])), src.uncoupled[2]...)) + isdual_dst = (Base.tail(src.isdual[1]), + (!first(src.isdual[1]), src.isdual[2]...)) + dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + + U = transformation_matrix(foldright, dst, src) + return dst, U +end + +# TODO: verify if this can be computed through an adjoint +function foldleft(src::FusionTreeBlock) + uncoupled_dst = ((dual(first(src.uncoupled[2])), src.uncoupled[1]...), + Base.tail(src.uncoupled[2])) + isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...), + Base.tail(src.isdual[2])) + dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + + U = transformation_matrix(foldleft, dst, src) + return dst, U +end + +function cycleclockwise(src::FusionTreeBlock) + if numout(src) > 0 + tmp, U₁ = foldright(src) + dst, U₂ = bendleft(tmp) + else + tmp, U₁ = bendleft(src) + dst, U₂ = foldright(tmp) + end + return dst, U₂ * U₁ +end + +function cycleanticlockwise(src::FusionTreeBlock) + if numin(src) > 0 + tmp, U₁ = foldleft(src) + dst, U₂ = bendright(tmp) + else + tmp, U₁ = bendright(src) + dst, U₂ = foldleft(tmp) + end + return dst, U₂ * U₁ +end + +@inline function repartition(src::FusionTreeBlock, N::Int) + @assert 0 <= N <= numind(src) + return _recursive_repartition(src, Val(N)) +end + +function _repartition_type(I, N, N₁, N₂) + return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}} +end +function _recursive_repartition(src::FusionTreeBlock{I,N₁,N₂}, + ::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N} + if N == N₁ + dst = src + U = zeros(sectorscalartype(I), length(dst), length(src)) + copyto!(U, LinearAlgebra.I) + return dst, U + end + + N == N₁ - 1 && return bendright(src) + N == N₁ + 1 && return bendleft(src) + + tmp, U₁ = N < N₁ ? bendright(src) : bendleft(src) + dst, U₂ = _recursive_repartition(tmp, Val(N)) + return dst, U₂ * U₁ +end + +function Base.transpose(src::FusionTreeBlock, p::Index2Tuple{N₁,N₂}) where {N₁,N₂} + N = N₁ + N₂ + @assert numind(src) == N + p′ = linearizepermutation(p..., numout(src), numin(src)) + @assert iscyclicpermutation(p′) + return _fstranspose((src, p)) +end + +const _FSTransposeKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂}} + +@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁, + N₂}, + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} + src, (p1, p2) = key + + N = N₁ + N₂ + p = linearizepermutation(p1, p2, numout(src), numin(src)) + + dst, U = repartition(src, N₁) + length(p) == 0 && return dst, U + i1 = findfirst(==(1), p)::Int + i1 == 1 && return dst, U + + Nhalf = N >> 1 + while 1 < i1 ≤ Nhalf + dst, U_tmp = cycleanticlockwise(dst) + U = U_tmp * U + i1 -= 1 + end + while Nhalf < i1 + dst, U_tmp = cycleclockwise(dst) + U = U_tmp * U + i1 = mod1(i1 + 1, N) + end + + return dst, U +end + +function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I} + if FusionStyle(I) == UniqueFusion() + return NoCache() + else + return GlobalLRUCache() + end +end + +function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N} + 1 <= i < N || + throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) + + uncoupled = src.uncoupled[1] + uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i) + uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1) + isdual = src.isdual[1] + isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) + isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) + dst = FusionTreeBlock((uncoupled′, ()), (isdual′, ())) + + # TODO: do we want to rewrite `artin_braid` to take double trees instead? + U = transformation_matrix(dst, src) do (f₁, f₂) + return ((f₁′, f₂) => c for (f₁′, c) in artin_braid(f₁, i; inv)) + end + return dst, U +end + +function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int}, + levels::NTuple{N,Int}) where {I,N} + TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) + + if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding + uncoupled′ = TupleTools._permute(src.uncoupled[1], p) + isdual′ = TupleTools._permute(src.isdual[1], p) + dst = FusionTreeBlock(uncoupled′, isdual′) + U = transformation_matrix(dst, src) do (f₁, f₂) + return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels)) + end + else + dst, U = repartition(src, N) # TODO: can we avoid this? + for s in permutation2swaps(p) + inv = levels[s] > levels[s + 1] + dst, U_tmp = artin_braid(dst, s; inv) + U = U_tmp * U + end + end + return dst, U +end + +function braid(src::FusionTreeBlock{I}, p::Index2Tuple{N₁,N₂}, + levels::Index2Tuple) where {I,N₁,N₂} + @assert numind(src) == N₁ + N₂ + @assert numout(src) == length(levels[1]) && numin(src) == length(levels[2]) + @assert TupleTools.isperm((p[1]..., p[2]...)) + return _fsbraid((src, p, levels)) +end + +const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂},Index2Tuple} + +@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂}, + Matrix{sectorscalartype(I)}} where {I, + N₁, + N₂} + src, (p1, p2), (l1, l2) = key + + p = linearizepermutation(p1, p2, numout(src), numin(src)) + levels = (l1..., reverse(l2)...) + + dst, U = repartition(src, numind(src)) + + if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding + uncoupled′ = TupleTools._permute(dst.uncoupled[1], p) + isdual′ = TupleTools._permute(dst.isdual[1], p) + + dst′ = FusionTreeBlock(uncoupled′, isdual′) + U_tmp = transformation_matrix(dst′, dst) do (f₁, f₂) + return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels)) + end + dst = dst′ + U = U_tmp * U + else + for s in permutation2swaps(p) + inv = levels[s] > levels[s + 1] + dst, U_tmp = artin_braid(dst, s; inv) + U = U_tmp * U + end + end + + if N₂ == 0 + return dst, U + else + dst, U_tmp = repartition(dst, N₁) + U = U_tmp * U + return dst, U + end +end + +function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I} + if FusionStyle(I) isa UniqueFusion + return NoCache() + else + return GlobalLRUCache() + end +end + +function permute(src::FusionTreeBlock{I}, p::Index2Tuple) where {I} + @assert BraidingStyle(I) isa SymmetricBraiding + levels1 = ntuple(identity, numout(src)) + levels2 = numout(src) .+ ntuple(identity, numin(src)) + return braid(src, p, (levels1, levels2)) +end diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index b44f73b34..1da733ece 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -227,7 +227,7 @@ end # Fusion tree iterators include("iterator.jl") -include("uncouplediterator.jl") +include("fusiontreeblocks.jl") # Manipulate fusion trees include("manipulations.jl") diff --git a/src/fusiontrees/uncouplediterator.jl b/src/fusiontrees/uncouplediterator.jl deleted file mode 100644 index d5906b417..000000000 --- a/src/fusiontrees/uncouplediterator.jl +++ /dev/null @@ -1,352 +0,0 @@ -struct OuterTreeIterator{I<:Sector,N₁,N₂} - uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}} - isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}} -end - -sectortype(::Type{<:OuterTreeIterator{I}}) where {I} = I -numout(fs::OuterTreeIterator) = numout(typeof(fs)) -numout(::Type{<:OuterTreeIterator{I,N₁}}) where {I,N₁} = N₁ -numin(fs::OuterTreeIterator) = numin(typeof(fs)) -numin(::Type{<:OuterTreeIterator{I,N₁,N₂}}) where {I,N₁,N₂} = N₂ -numind(fs::OuterTreeIterator) = numind(typeof(fs)) -numind(::Type{T}) where {T<:OuterTreeIterator} = numin(T) + numout(T) - -# TODO: should we make this an actual iterator? -function fusiontrees(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - - trees = Vector{Tuple{F₁,F₂}}(undef, 0) - for c in blocksectors(iter), f₁ in fusiontrees(iter.uncoupled[1], c, iter.isdual[1]), - f₂ in fusiontrees(iter.uncoupled[2], c, iter.isdual[2]) - push!(trees, (f₁, f₂)) - end - return trees -end - -# TODO: better implementation -Base.length(iter::OuterTreeIterator) = length(fusiontrees(iter)) - -function blocksectors(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - I == Trivial && return (Trivial(),) - - bs_codomain = Vector{I}() - if N₁ == 0 - push!(bs_codomain, one(I)) - elseif N₁ == 1 - push!(bs_codomain, only(iter.uncoupled[1])) - else - for c in ⊗(iter.uncoupled[1]...) - if !(c in bs_codomain) - push!(bs_codomain, c) - end - end - end - - bs_domain = Vector{I}() - if N₂ == 0 - push!(bs_domain, one(I)) - elseif N₂ == 1 - push!(bs_domain, only(iter.uncoupled[2])) - else - for c in ⊗(iter.uncoupled[2]...) - if !(c in bs_domain) - push!(bs_domain, c) - end - end - end - - return sort!(collect(intersect(bs_codomain, bs_domain))) -end - -# Manipulations -# ------------- - -function bendright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - uncoupled_dst = (TupleTools.front(fs_src.uncoupled[1]), - (fs_src.uncoupled[2]..., dual(fs_src.uncoupled[1][end]))) - isdual_dst = (TupleTools.front(fs_src.isdual[1]), - (fs_src.isdual[2]..., !(fs_src.isdual[1][end]))) - fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, f) in enumerate(trees_src) - for (f′, c) in bendright(f) - row = indexmap[f′] - U[row, col] = c - end - end - - return fs_dst, U -end - -# TODO: verify if this can be computed through an adjoint -function bendleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - uncoupled_dst = ((fs_src.uncoupled[1]..., dual(fs_src.uncoupled[2][end])), - TupleTools.front(fs_src.uncoupled[2])) - isdual_dst = ((fs_src.isdual[1]..., !(fs_src.isdual[2][end])), - TupleTools.front(fs_src.isdual[2])) - fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, f) in enumerate(trees_src) - for (f′, c) in bendleft(f) - row = indexmap[f′] - U[row, col] = c - end - end - - return fs_dst, U -end - -function foldright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - uncoupled_dst = (Base.tail(fs_src.uncoupled[1]), - (dual(first(fs_src.uncoupled[1])), fs_src.uncoupled[2]...)) - isdual_dst = (Base.tail(fs_src.isdual[1]), - (!first(fs_src.isdual[1]), fs_src.isdual[2]...)) - fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, f) in enumerate(trees_src) - for (f′, c) in foldright(f) - row = indexmap[f′] - U[row, col] = c - end - end - - return fs_dst, U -end - -# TODO: verify if this can be computed through an adjoint -function foldleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - uncoupled_dst = ((dual(first(fs_src.uncoupled[2])), fs_src.uncoupled[1]...), - Base.tail(fs_src.uncoupled[2])) - isdual_dst = ((!first(fs_src.isdual[2]), fs_src.isdual[1]...), - Base.tail(fs_src.isdual[2])) - fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, f) in enumerate(trees_src) - for (f′, c) in foldleft(f) - row = indexmap[f′] - U[row, col] = c - end - end - - return fs_dst, U -end - -function cycleclockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - if N₁ > 0 - fs_tmp, U₁ = foldright(fs_src) - fs_dst, U₂ = bendleft(fs_tmp) - else - fs_tmp, U₁ = bendleft(fs_src) - fs_dst, U₂ = foldright(fs_tmp) - end - return fs_dst, U₂ * U₁ -end - -function cycleanticlockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} - if N₂ > 0 - fs_tmp, U₁ = foldleft(fs_src) - fs_dst, U₂ = bendright(fs_tmp) - else - fs_tmp, U₁ = bendright(fs_src) - fs_dst, U₂ = foldleft(fs_tmp) - end - return fs_dst, U₂ * U₁ -end - -@inline function repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, N::Int) where {I,N₁,N₂} - @assert 0 <= N <= N₁ + N₂ - return _recursive_repartition(fs_src, Val(N)) -end - -function _repartition_type(I, N, N₁, N₂) - return Tuple{OuterTreeIterator{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}} -end -function _recursive_repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, - ::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N} - if N == N₁ - fs_dst = fs_src - U = zeros(sectorscalartype(I), length(fs_dst), length(fs_src)) - copyto!(U, LinearAlgebra.I) - return fs_dst, U - end - - N == N₁ - 1 && return bendright(fs_src) - N == N₁ + 1 && return bendleft(fs_src) - - fs_tmp, U₁ = N < N₁ ? bendright(fs_src) : bendleft(fs_src) - fs_dst, U₂ = _recursive_repartition(fs_tmp, Val(N)) - return fs_dst, U₂ * U₁ -end - -function Base.transpose(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} - N = N₁ + N₂ - @assert numind(fs_src) == N - p′ = linearizepermutation(p..., numout(fs_src), numin(fs_src)) - @assert iscyclicpermutation(p′) - return _fstranspose((fs_src, p)) -end - -const _FSTransposeKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂}} - -@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁, - N₂}, - Matrix{sectorscalartype(I)}} where {I, - N₁, - N₂} - fs_src, (p1, p2) = key - - N = N₁ + N₂ - p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) - - fs_dst, U = repartition(fs_src, N₁) - length(p) == 0 && return fs_dst, U - i1 = findfirst(==(1), p)::Int - i1 == 1 && return fs_dst, U - - Nhalf = N >> 1 - while 1 < i1 ≤ Nhalf - fs_dst, U_tmp = cycleanticlockwise(fs_dst) - U = U_tmp * U - i1 -= 1 - end - while Nhalf < i1 - fs_dst, U_tmp = cycleclockwise(fs_dst) - U = U_tmp * U - i1 = mod1(i1 + 1, N) - end - - return fs_dst, U -end - -function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I} - if FusionStyle(I) == UniqueFusion() - return NoCache() - else - return GlobalLRUCache() - end -end - -function artin_braid(fs_src::OuterTreeIterator{I,N,0}, i; inv::Bool=false) where {I,N} - 1 <= i < N || - throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) - - uncoupled = fs_src.uncoupled[1] - uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i) - uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1) - - isdual = fs_src.isdual[1] - isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) - isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) - - fs_dst = OuterTreeIterator((uncoupled′, ()), (isdual′, ())) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, (f₁, f₂)) in enumerate(trees_src) - for (f₁′, c) in artin_braid(f₁, i; inv) - row = indexmap[(f₁′, f₂)] - U[row, col] = c - end - end - - return fs_dst, U -end - -function braid(fs_src::OuterTreeIterator{I,N,0}, p::NTuple{N,Int}, - levels::NTuple{N,Int}) where {I,N} - TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) - - if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding - uncoupled′ = TupleTools._permute(fs_src.uncoupled[1], p) - isdual′ = TupleTools._permute(fs_src.isdual[1], p) - fs_dst = OuterTreeIterator(uncoupled′, isdual′) - - trees_src = fusiontrees(fs_src) - trees_dst = fusiontrees(fs_dst) - indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) - U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) - - for (col, (f₁, f₂)) in enumerate(trees_src) - for (f₁′, c) in braid(f₁, p, levels) - row = indexmap[(f₁′, f₂)] - U[row, col] = c - end - end - - return fs_dst, U - end - - fs_dst, U = repartition(fs_src, N) # TODO: can we avoid this? - for s in permutation2swaps(p) - inv = levels[s] > levels[s + 1] - fs_dst, U_tmp = artin_braid(fs_dst, s; inv) - U = U_tmp * U - end - return fs_dst, U -end - -function braid(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}, - levels::Index2Tuple) where {I,N₁,N₂} - @assert numind(fs_src) == N₁ + N₂ - @assert numout(fs_src) == length(levels[1]) && numin(fs_src) == length(levels[2]) - @assert TupleTools.isperm((p[1]..., p[2]...)) - return _fsbraid((fs_src, p, levels)) -end - -const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂},Index2Tuple} - -@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,N₂}, - Matrix{sectorscalartype(I)}} where {I, - N₁, - N₂} - fs_src, (p1, p2), (l1, l2) = key - - p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) - levels = (l1..., reverse(l2)...) - - fs_dst, U = repartition(fs_src, numind(fs_src)) - fs_dst, U_tmp = braid(fs_dst, p, levels) - U = U_tmp * U - fs_dst, U_tmp = repartition(fs_dst, N₁) - U = U_tmp * U - return fs_dst, U -end - -function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I} - if FusionStyle(I) isa UniqueFusion - return NoCache() - else - return GlobalLRUCache() - end -end - -function permute(fs_src::OuterTreeIterator{I}, p::Index2Tuple) where {I} - @assert BraidingStyle(I) isa SymmetricBraiding - levels1 = ntuple(identity, numout(fs_src)) - levels2 = numout(fs_src) .+ ntuple(identity, numin(fs_src)) - return braid(fs_src, p, (levels1, levels2)) -end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 8cec1fefd..7a009ad15 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -72,7 +72,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) for cod_uncoupled_src in sectors(codomain(Vsrc)), dom_uncoupled_src in sectors(domain(Vsrc)) - fs_src = OuterTreeIterator((cod_uncoupled_src, dom_uncoupled_src), isdual_src) + fs_src = FusionTreeBlock((cod_uncoupled_src, dom_uncoupled_src), isdual_src) trees_src = fusiontrees(fs_src) isempty(trees_src) && continue From baf12553b7297a9ebd961a4dede77a59e21cbe11 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 31 Jul 2025 19:21:13 -0400 Subject: [PATCH 06/28] Fix unbound type parameter --- src/fusiontrees/fusiontreeblocks.jl | 19 ++++++++++--------- src/tensors/treetransformers.jl | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index f26834142..634c28bc4 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -2,8 +2,9 @@ struct FusionTreeBlock{I,N₁,N₂,F<:FusionTreePair{I,N₁,N₂}} trees::Vector{F} end -function FusionTreeBlock(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}, - isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}) where {I<:Sector,N₁,N₂} +function FusionTreeBlock{I}(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}, + isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}) where {I<:Sector, + N₁,N₂} F₁ = fusiontreetype(I, N₁) F₂ = fusiontreetype(I, N₂) trees = Vector{Tuple{F₁,F₂}}(undef, 0) @@ -66,7 +67,7 @@ function bendright(src::FusionTreeBlock) (src.uncoupled[2]..., dual(src.uncoupled[1][end]))) isdual_dst = (TupleTools.front(src.isdual[1]), (src.isdual[2]..., !(src.isdual[1][end]))) - dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) U = transformation_matrix(bendright, dst, src) return dst, U @@ -78,7 +79,7 @@ function bendleft(src::FusionTreeBlock) TupleTools.front(src.uncoupled[2])) isdual_dst = ((src.isdual[1]..., !(src.isdual[2][end])), TupleTools.front(src.isdual[2])) - dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) U = transformation_matrix(bendleft, dst, src) return dst, U @@ -89,7 +90,7 @@ function foldright(src::FusionTreeBlock) (dual(first(src.uncoupled[1])), src.uncoupled[2]...)) isdual_dst = (Base.tail(src.isdual[1]), (!first(src.isdual[1]), src.isdual[2]...)) - dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) U = transformation_matrix(foldright, dst, src) return dst, U @@ -101,7 +102,7 @@ function foldleft(src::FusionTreeBlock) Base.tail(src.uncoupled[2])) isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...), Base.tail(src.isdual[2])) - dst = FusionTreeBlock(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) U = transformation_matrix(foldleft, dst, src) return dst, U @@ -212,7 +213,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N isdual = src.isdual[1] isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) - dst = FusionTreeBlock((uncoupled′, ()), (isdual′, ())) + dst = FusionTreeBlock{I}((uncoupled′, ()), (isdual′, ())) # TODO: do we want to rewrite `artin_braid` to take double trees instead? U = transformation_matrix(dst, src) do (f₁, f₂) @@ -228,7 +229,7 @@ function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int}, if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding uncoupled′ = TupleTools._permute(src.uncoupled[1], p) isdual′ = TupleTools._permute(src.isdual[1], p) - dst = FusionTreeBlock(uncoupled′, isdual′) + dst = FusionTreeBlock{I}(uncoupled′, isdual′) U = transformation_matrix(dst, src) do (f₁, f₂) return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels)) end @@ -268,7 +269,7 @@ const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N uncoupled′ = TupleTools._permute(dst.uncoupled[1], p) isdual′ = TupleTools._permute(dst.isdual[1], p) - dst′ = FusionTreeBlock(uncoupled′, isdual′) + dst′ = FusionTreeBlock{I}(uncoupled′, isdual′) U_tmp = transformation_matrix(dst′, dst) do (f₁, f₂) return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels)) end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 7a009ad15..d31706dc4 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -72,7 +72,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) for cod_uncoupled_src in sectors(codomain(Vsrc)), dom_uncoupled_src in sectors(domain(Vsrc)) - fs_src = FusionTreeBlock((cod_uncoupled_src, dom_uncoupled_src), isdual_src) + fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src) trees_src = fusiontrees(fs_src) isempty(trees_src) && continue From 4e7c9e3ee8b5663cfd1a8a13103268e13d9633fe Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 31 Jul 2025 19:22:27 -0400 Subject: [PATCH 07/28] refactor repartition to unroll loop --- src/fusiontrees/fusiontreeblocks.jl | 52 +++++++++++++++++++---------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 634c28bc4..03c73c7f1 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -132,27 +132,45 @@ end @inline function repartition(src::FusionTreeBlock, N::Int) @assert 0 <= N <= numind(src) - return _recursive_repartition(src, Val(N)) + return repartition(src, Val(N)) end -function _repartition_type(I, N, N₁, N₂) - return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}} -end -function _recursive_repartition(src::FusionTreeBlock{I,N₁,N₂}, - ::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N} - if N == N₁ - dst = src - U = zeros(sectorscalartype(I), length(dst), length(src)) - copyto!(U, LinearAlgebra.I) - return dst, U - end +#= +Using a generated function here to ensure type stability by unrolling the loops: +```julia +dst, U = bendleft/right(src) - N == N₁ - 1 && return bendright(src) - N == N₁ + 1 && return bendleft(src) +# repeat the following 2 lines N - 1 times +dst, Utmp = bendleft/right(dst) +U = Utmp * U - tmp, U₁ = N < N₁ ? bendright(src) : bendleft(src) - dst, U₂ = _recursive_repartition(tmp, Val(N)) - return dst, U₂ * U₁ +return dst, U +``` +=# +@generated function repartition(src::FusionTreeBlock, ::Val{N}) where {N} + return _repartition_body(numout(src) - N) +end +function _repartition_body(N) + if N == 0 + ex = quote + T = sectorscalartype(sectortype(src)) + U = copyto!(zeros(T, length(src), length(src)), LinearAlgebra.I) + return src, U + end + else + f = N < 0 ? bendleft : bendright + ex_rep = Expr(:block) + for _ in 1:(abs(N) - 1) + push!(ex_rep.args, :((dst, Utmp) = $f(dst))) + push!(ex_rep.args, :(U = Utmp * U)) + end + ex = quote + dst, U = $f(src) + $ex_rep + return dst, U + end + end + return ex end function Base.transpose(src::FusionTreeBlock, p::Index2Tuple{N₁,N₂}) where {N₁,N₂} From c808d4b45c5907f8f3739778aa48aa6895bc81c9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:15:52 +0200 Subject: [PATCH 08/28] dont depend on intricate scoping rules --- src/fusiontrees/fusiontreeblocks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 03c73c7f1..2121088cc 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -249,7 +249,7 @@ function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int}, isdual′ = TupleTools._permute(src.isdual[1], p) dst = FusionTreeBlock{I}(uncoupled′, isdual′) U = transformation_matrix(dst, src) do (f₁, f₂) - return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels)) + return ((f₁′, f₂) => c for (f₁′, c) in braid(f₁, p, levels)) end else dst, U = repartition(src, N) # TODO: can we avoid this? From 57047339cdf82371dfe9b63e5b41e3205846ccf1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:16:16 +0200 Subject: [PATCH 09/28] Refactor bendright to avoid extra dictionary --- src/fusiontrees/fusiontreeblocks.jl | 50 +++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 2121088cc..fb1934ce5 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -67,9 +67,55 @@ function bendright(src::FusionTreeBlock) (src.uncoupled[2]..., dual(src.uncoupled[1][end]))) isdual_dst = (TupleTools.front(src.isdual[1]), (src.isdual[2]..., !(src.isdual[1][end]))) - dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) + I = sectortype(src) + N₁ = numout(src) + N₂ = numin(src) + @assert N₁ > 0 + + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + U = zeros(sectorscalartype(I), length(dst), length(src)) + + for (col, (f₁, f₂)) in enumerate(fusiontrees(src)) + c = f₁.coupled + a = N₁ == 1 ? leftone(f₁.uncoupled[1]) : + (N₁ == 2 ? f₁.uncoupled[1] : f₁.innerlines[end]) + b = f₁.uncoupled[N₁] + + uncoupled1 = TupleTools.front(f₁.uncoupled) + isdual1 = TupleTools.front(f₁.isdual) + inner1 = N₁ > 2 ? TupleTools.front(f₁.innerlines) : () + vertices1 = N₁ > 1 ? TupleTools.front(f₁.vertices) : () + f₁′ = FusionTree(uncoupled1, a, isdual1, inner1, vertices1) + + uncoupled2 = (f₂.uncoupled..., dual(b)) + isdual2 = (f₂.isdual..., !(f₁.isdual[N₁])) + inner2 = N₂ > 1 ? (f₂.innerlines..., c) : () + + coeff₀ = sqrtdim(c) * invsqrtdim(a) + if f₁.isdual[N₁] + coeff₀ *= conj(frobeniusschur(dual(b))) + end + if FusionStyle(I) isa MultiplicityFreeFusion + coeff = coeff₀ * Bsymbol(a, b, c) + vertices2 = N₂ > 0 ? (f₂.vertices..., 1) : () + f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) + row = indexmap[(f₁′, f₂′)] + @inbounds U[row, col] = coeff + else + Bmat = Bsymbol(a, b, c) + μ = N₁ > 1 ? f₁.vertices[end] : 1 + for ν in axes(Bmat, 2) + coeff = coeff₀ * Bmat[μ, ν] + iszero(coeff) && continue + vertices2 = N₂ > 0 ? (f₂.vertices..., ν) : () + f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) + row = indexmap[(f₁′, f₂′)] + @inbounds U[row, col] = coeff + end + end + end - U = transformation_matrix(bendright, dst, src) return dst, U end From 3dc1a47d7dadff09ede629635a134f52dfc11f7d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:24:30 +0200 Subject: [PATCH 10/28] Refactor bendleft to avoid extra dictionaries --- src/fusiontrees/fusiontreeblocks.jl | 53 +++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index fb1934ce5..faec92cf4 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -119,15 +119,62 @@ function bendright(src::FusionTreeBlock) return dst, U end -# TODO: verify if this can be computed through an adjoint +# !! note that this is more or less a copy of bendright through +# (f1, f2) => conj(coeff) for ((f2, f1), coeff) in bendleft(src) function bendleft(src::FusionTreeBlock) uncoupled_dst = ((src.uncoupled[1]..., dual(src.uncoupled[2][end])), TupleTools.front(src.uncoupled[2])) isdual_dst = ((src.isdual[1]..., !(src.isdual[2][end])), TupleTools.front(src.isdual[2])) - dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) + I = sectortype(src) + N₁ = numin(src) + N₂ = numout(src) + @assert N₁ > 0 + + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + U = zeros(sectorscalartype(I), length(dst), length(src)) + + for (col, (f₂, f₁)) in enumerate(fusiontrees(src)) + c = f₁.coupled + a = N₁ == 1 ? leftone(f₁.uncoupled[1]) : + (N₁ == 2 ? f₁.uncoupled[1] : f₁.innerlines[end]) + b = f₁.uncoupled[N₁] + + uncoupled1 = TupleTools.front(f₁.uncoupled) + isdual1 = TupleTools.front(f₁.isdual) + inner1 = N₁ > 2 ? TupleTools.front(f₁.innerlines) : () + vertices1 = N₁ > 1 ? TupleTools.front(f₁.vertices) : () + f₁′ = FusionTree(uncoupled1, a, isdual1, inner1, vertices1) + + uncoupled2 = (f₂.uncoupled..., dual(b)) + isdual2 = (f₂.isdual..., !(f₁.isdual[N₁])) + inner2 = N₂ > 1 ? (f₂.innerlines..., c) : () + + coeff₀ = sqrtdim(c) * invsqrtdim(a) + if f₁.isdual[N₁] + coeff₀ *= conj(frobeniusschur(dual(b))) + end + if FusionStyle(I) isa MultiplicityFreeFusion + coeff = coeff₀ * Bsymbol(a, b, c) + vertices2 = N₂ > 0 ? (f₂.vertices..., 1) : () + f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) + row = indexmap[(f₂′, f₁′)] + @inbounds U[row, col] = conj(coeff) + else + Bmat = Bsymbol(a, b, c) + μ = N₁ > 1 ? f₁.vertices[end] : 1 + for ν in axes(Bmat, 2) + coeff = coeff₀ * Bmat[μ, ν] + iszero(coeff) && continue + vertices2 = N₂ > 0 ? (f₂.vertices..., ν) : () + f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) + row = indexmap[(f₂′, f₁′)] + @inbounds U[row, col] = conj(coeff) + end + end + end - U = transformation_matrix(bendleft, dst, src) return dst, U end From f8eb2074f24092de5ea5c688b31b9af4ca4fcba7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:32:35 +0200 Subject: [PATCH 11/28] Refactor foldright to avoid extra dictionaries --- src/fusiontrees/fusiontreeblocks.jl | 61 +++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index faec92cf4..e34c8c03f 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -181,11 +181,66 @@ end function foldright(src::FusionTreeBlock) uncoupled_dst = (Base.tail(src.uncoupled[1]), (dual(first(src.uncoupled[1])), src.uncoupled[2]...)) - isdual_dst = (Base.tail(src.isdual[1]), - (!first(src.isdual[1]), src.isdual[2]...)) + isdual_dst = (Base.tail(src.isdual[1]), (!first(src.isdual[1]), src.isdual[2]...)) + I = sectortype(src) + N₁ = numout(src) + N₂ = numin(src) + @assert N₁ > 0 dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) - U = transformation_matrix(foldright, dst, src) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + U = zeros(sectorscalartype(I), length(dst), length(src)) + + for (col, (f₁, f₂)) in enumerate(fusiontrees(src)) + # map first splitting vertex (a, b)<-c to fusion vertex b<-(dual(a), c) + a = f₁.uncoupled[1] + isduala = f₁.isdual[1] + factor = sqrtdim(a) + if !isduala + factor *= conj(frobeniusschur(a)) + end + c1 = dual(a) + c2 = f₁.coupled + uncoupled = Base.tail(f₁.uncoupled) + isdual = Base.tail(f₁.isdual) + if FusionStyle(I) isa UniqueFusion + c = first(c1 ⊗ c2) + fl = FusionTree{I}(Base.tail(f₁.uncoupled), c, Base.tail(f₁.isdual)) + fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...)) + row = indexmap[(fl, fr)] + @inbounds U[row, col] = factor + else + if N₁ == 1 + cset = (leftone(c1),) # or rightone(a) + elseif N₁ == 2 + cset = (f₁.uncoupled[2],) + else + cset = ⊗(Base.tail(f₁.uncoupled)...) + end + for c in c1 ⊗ c2 + c ∈ cset || continue + for μ in 1:Nsymbol(c1, c2, c) + fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,)) + for (fl′, coeff1) in insertat(fc, 2, f₁) + N₁ > 1 && !isone(fl′.innerlines[1]) && continue + coupled = fl′.coupled + uncoupled = Base.tail(Base.tail(fl′.uncoupled)) + isdual = Base.tail(Base.tail(fl′.isdual)) + inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines)) + vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices)) + fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) + for (fr, coeff2) in insertat(fc, 2, f₂) + coeff = factor * coeff1 * conj(coeff2) + row = indexmap[(fl, fr)] + @inbounds U[row, col] = coeff + end + end + end + end + end + end + return dst, U end From 8c129a4cfc2119567ff43ed666e61d6ebccda5c9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:54:54 +0200 Subject: [PATCH 12/28] Refactor foldleft to avoid extra dictionaries --- src/fusiontrees/fusiontreeblocks.jl | 61 +++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index e34c8c03f..a0b9f8432 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -244,15 +244,70 @@ function foldright(src::FusionTreeBlock) return dst, U end -# TODO: verify if this can be computed through an adjoint +# !! note that this is more or less a copy of foldright through +# (f1, f2) => conj(coeff) for ((f2, f1), coeff) in foldright(src) function foldleft(src::FusionTreeBlock) uncoupled_dst = ((dual(first(src.uncoupled[2])), src.uncoupled[1]...), Base.tail(src.uncoupled[2])) isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...), Base.tail(src.isdual[2])) - dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) + I = sectortype(src) + N₁ = numin(src) + N₂ = numout(src) + @assert N₁ > 0 + + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + U = zeros(sectorscalartype(I), length(dst), length(src)) - U = transformation_matrix(foldleft, dst, src) + for (col, (f₂, f₁)) in enumerate(fusiontrees(src)) + # map first splitting vertex (a, b)<-c to fusion vertex b<-(dual(a), c) + a = f₁.uncoupled[1] + isduala = f₁.isdual[1] + factor = sqrtdim(a) + if !isduala + factor *= conj(frobeniusschur(a)) + end + c1 = dual(a) + c2 = f₁.coupled + uncoupled = Base.tail(f₁.uncoupled) + isdual = Base.tail(f₁.isdual) + if FusionStyle(I) isa UniqueFusion + c = first(c1 ⊗ c2) + fl = FusionTree{I}(Base.tail(f₁.uncoupled), c, Base.tail(f₁.isdual)) + fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...)) + row = indexmap[(fr, fl)] + @inbounds U[row, col] = conj(factor) + else + if N₁ == 1 + cset = (leftone(c1),) # or rightone(a) + elseif N₁ == 2 + cset = (f₁.uncoupled[2],) + else + cset = ⊗(Base.tail(f₁.uncoupled)...) + end + for c in c1 ⊗ c2 + c ∈ cset || continue + for μ in 1:Nsymbol(c1, c2, c) + fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,)) + for (fl′, coeff1) in insertat(fc, 2, f₁) + N₁ > 1 && !isone(fl′.innerlines[1]) && continue + coupled = fl′.coupled + uncoupled = Base.tail(Base.tail(fl′.uncoupled)) + isdual = Base.tail(Base.tail(fl′.isdual)) + inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines)) + vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices)) + fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) + for (fr, coeff2) in insertat(fc, 2, f₂) + coeff = factor * coeff1 * conj(coeff2) + row = indexmap[(fr, fl)] + @inbounds U[row, col] = conj(coeff) + end + end + end + end + end + end return dst, U end From 8b2a15dbdc920bd71d0567528b562ef0186575dd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 13 Aug 2025 14:56:22 +0200 Subject: [PATCH 13/28] remove unused variable --- src/fusiontrees/manipulations.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 9a8402870..09aa345ba 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -356,7 +356,6 @@ function foldright((f₁, f₂)::FusionTreePair{I,N₁,N₂}) where {I,N₁,N₂ fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...)) return fusiontreedict(I)((fl, fr) => factor) else - hasmultiplicities = FusionStyle(a) isa GenericFusion local newtrees if N₁ == 1 cset = (leftone(c1),) # or rightone(a) From 9311fb03beb4843cb6981b38a0c40aa551dc98e3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 14 Aug 2025 09:48:56 +0200 Subject: [PATCH 14/28] some docs fixes --- docs/Project.toml | 1 + docs/make.jl | 1 + docs/src/lib/sectors.md | 10 +++++----- docs/src/man/sectors.md | 6 +++--- src/fusiontrees/manipulations.jl | 18 ++++++++---------- src/tensors/linalg.jl | 2 -- 6 files changed, 18 insertions(+), 20 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 695415bc4..bb471e442 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" [compat] diff --git a/docs/make.jl b/docs/make.jl index 0c2e0a35a..398fde0f7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,7 @@ using Documenter using Random using TensorKit, TensorKitSectors +using TensorKit: FusionTreePair, Index2Tuple pages = ["Home" => "index.md", "Manual" => ["man/intro.md", "man/tutorial.md", "man/categories.md", diff --git a/docs/src/lib/sectors.md b/docs/src/lib/sectors.md index 4029b68cf..09e725b0e 100644 --- a/docs/src/lib/sectors.md +++ b/docs/src/lib/sectors.md @@ -90,7 +90,7 @@ insertat split merge elementary_trace -planar_trace(f::FusionTree{I,N}, q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {I<:Sector,N,N₃} +planar_trace(f::FusionTree{I,N}, q::Index2Tuple{N₃,N₃}) where {I<:Sector,N,N₃} artin_braid braid(f::FusionTree{I,N}, levels::NTuple{N,Int}, p::NTuple{N,Int}) where {I<:Sector,N} permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I<:Sector,N} @@ -113,8 +113,8 @@ Finally, these are used to define large manipulations of fusion-splitting tree p are then used in the index manipulation of `AbstractTensorMap` objects. The following methods defined on fusion splitting tree pairs have an associated definition for tensors. ```@docs -repartition(::FusionTree{I,N₁}, ::FusionTree{I,N₂}, ::Int) where {I<:Sector,N₁,N₂} -transpose(::FusionTree{I}, ::FusionTree{I}, ::IndexTuple{N₁}, ::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} -braid(::FusionTree{I}, ::FusionTree{I}, ::IndexTuple, ::IndexTuple, ::IndexTuple{N₁}, ::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} -permute(::FusionTree{I}, ::FusionTree{I}, ::IndexTuple{N₁}, ::IndexTuple{N₂}) where {I<:Sector,N₁,N₂} +repartition(::FusionTreePair, ::Int) +transpose(::FusionTreePair{I}, ::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} +braid(::FusionTreePair, ::Index2Tuple, ::Index2Tuple) +permute(::FusionTreePair, ::Index2Tuple) ``` diff --git a/docs/src/man/sectors.md b/docs/src/man/sectors.md index 30590ef00..768e3f608 100644 --- a/docs/src/man/sectors.md +++ b/docs/src/man/sectors.md @@ -1155,7 +1155,7 @@ the splitting tree. The `FusionTree` interface to duality and line bending is given by -[`repartition(f1::FusionTree{I,N₁}, f2::FusionTree{I,N₂}, N::Int)`](@ref repartition) +[`repartition(f1::FusionTreePair{I,N₁,N₂}, N::Int)`](@ref repartition) which takes a splitting tree `f1` with `N₁` outgoing sectors, a fusion tree `f2` with `N₂` incoming sectors, and applies line bending such that the resulting splitting and fusion @@ -1180,7 +1180,7 @@ With this basic function, we can now perform arbitrary combinations of braids or permutations with line bendings, to completely reshuffle where sectors appear. The interface provided for this is given by -[`braid(f1::FusionTree{I,N₁}, f2::FusionTree{I,N₂}, levels1::NTuple{N₁,Int}, levels2::NTuple{N₂,Int}, p1::NTuple{N₁′,Int}, p2::NTuple{N₂′,Int})`](@ref braid(::FusionTree{I}, ::FusionTree{I}, ::IndexTuple, ::IndexTuple, ::IndexTuple{N₁}, ::IndexTuple{N₂}) where {I<:Sector,N₁,N₂}) +[`braid((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple, (levels1, levels2)::Index2Tuple)`](@ref braid(::TensorKit.FusionTreePair, ::Index2Tuple, ::Index2Tuple)) where we now have splitting tree `f1` with `N₁` outgoing sectors, a fusion tree `f2` with `N₂` incoming sectors, `levels1` and `levels2` assign a level or depth to the corresponding @@ -1206,7 +1206,7 @@ As before, there is a simplified interface for the case where `BraidingStyle(I) isa SymmetricBraiding` and the levels are not needed. This is simply given by -[`permute(f1::FusionTree{I,N₁}, f2::FusionTree{I,N₂}, p1::NTuple{N₁′,Int}, p2::NTuple{N₂′,Int})`](@ref permute(::FusionTree{I}, ::FusionTree{I}, ::IndexTuple{N₁}, ::IndexTuple{N₂}) where {I<:Sector,N₁,N₂}) +[`permute((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple)`](@ref permute(::FusionTreePair, ::Index2Tuple)) The `braid` and `permute` routines for double fusion trees will be the main access point for corresponding manipulations on tensors. As a consequence, results from this routine are diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 09aa345ba..3030e89d7 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -487,9 +487,9 @@ outgoing (`f₁`) and incoming sectors (`f₂`) respectively (with identical cou repartitioning the tree by bending incoming to outgoing sectors (or vice versa) in order to have `N` outgoing sectors. """ -@inline function repartition((f₁, f₂)::FusionTreePair{I,N₁,N₂}, N::Int) where {I,N₁,N₂} +@inline function repartition((f₁, f₂)::FusionTreePair, N::Int) f₁.coupled == f₂.coupled || throw(SectorMismatch()) - @assert 0 <= N <= N₁ + N₂ + @assert 0 <= N <= length(f₁) + length(f₂) return _recursive_repartition((f₁, f₂), Val(N)) end @@ -1002,8 +1002,7 @@ end # braid double fusion tree """ - braid((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}, - (levels1, levels2)::Index2Tuple) where {I,N₁,N₂} + braid((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple, (levels1, levels2)::Index2Tuple) -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a fusion-splitting tree pair that describes the fusion of a set of incoming @@ -1018,9 +1017,8 @@ respectively, which determines how indices braid. In particular, if `i` and `j` levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}, - (levels1, levels2)::Index2Tuple) where {I,N₁,N₂} - @assert length(f₁) + length(f₂) == N₁ + N₂ +function braid((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple, (levels1, levels2)::Index2Tuple) + @assert length(f₁) + length(f₂) == length(p1) + length(p2) @assert length(f₁) == length(levels1) && length(f₂) == length(levels2) @assert TupleTools.isperm((p1..., p2...)) return fsbraid(((f₁, f₂), (p1, p2), (levels1, levels2))) @@ -1056,7 +1054,7 @@ function CacheStyle(::typeof(fsbraid), k::FSBraidKey{I}) where {I<:Sector} end """ - permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁, N₂}) where {I, N₁, N₂} + permute((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple) -> <:AbstractDict{<:FusionTreePair{I, N₁, N₂}}, <:Number} Input is a double fusion tree that describes the fusion of a set of incoming uncoupled @@ -1066,8 +1064,8 @@ outgoing (`t1`) and incoming sectors (`t2`) respectively (with identical coupled repartitioning and permuting the tree such that sectors `p1` become outgoing and sectors `p2` become incoming. """ -function permute((f₁, f₂)::FusionTreePair{I}, (p1, p2)::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} - @assert BraidingStyle(I) isa SymmetricBraiding +function permute((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple) + @assert BraidingStyle(sectortype(f₁)) isa SymmetricBraiding levels1 = ntuple(identity, length(f₁)) levels2 = length(f₁) .+ ntuple(identity, length(f₂)) return braid((f₁, f₂), (p1, p2), (levels1, levels2)) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index f29bdf809..8a3c3d1e9 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -71,8 +71,6 @@ end Construct the identity endomorphism on space `V`, i.e. return a `t::TensorMap` with `domain(t) == codomain(t) == V`, where either `scalartype(t) = T` if `T` is a `Number` type or `storagetype(t) = T` if `T` is a `DenseVector` type. - -See also [`one!`](@ref). """ id, id! id(V::TensorSpace) = id(Float64, V) From ecbdaa62915a3e4dc7aa83e968dcf4272fdb3aac Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 14 Aug 2025 10:02:04 +0200 Subject: [PATCH 15/28] Avoid using `one(I)` --- src/fusiontrees/fusiontreeblocks.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index a0b9f8432..0022d4652 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -9,9 +9,16 @@ function FusionTreeBlock{I}(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}, F₂ = fusiontreetype(I, N₂) trees = Vector{Tuple{F₁,F₂}}(undef, 0) - cleft = N₁ == 0 ? (one(I),) : ⊗(uncoupled[1]...) - cright = N₂ == 0 ? (one(I),) : ⊗(uncoupled[2]...) - cs = sort!(collect(intersect(cleft, cright))) + if N₁ == N₂ == 0 + return FusionTreeBlock(trees) + elseif N₁ == 0 + cs = sort!(collect(filter(isone, ⊗(uncoupled[2]...)))) + elseif N₂ == 0 + cs = sort!(collect(filter(isone, ⊗(uncoupled[1]...)))) + else + cs = sort!(collect(intersect(⊗(uncoupled[1]...), ⊗(uncoupled[2]...)))) + end + for c in cs for f₁ in fusiontrees(uncoupled[1], c, isdual[1]), f₂ in fusiontrees(uncoupled[2], c, isdual[2]) From 0c05153d6a12ad0c19d6137d2b38617960bdde79 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 14 Aug 2025 10:26:19 +0200 Subject: [PATCH 16/28] format --- src/fusiontrees/manipulations.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 3030e89d7..92742efc1 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -1017,7 +1017,8 @@ respectively, which determines how indices braid. In particular, if `i` and `j` levels[j]`. This does not allow to encode the most general braid, but a general braid can be obtained by combining such operations. """ -function braid((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple, (levels1, levels2)::Index2Tuple) +function braid((f₁, f₂)::FusionTreePair, (p1, p2)::Index2Tuple, + (levels1, levels2)::Index2Tuple) @assert length(f₁) + length(f₂) == length(p1) + length(p2) @assert length(f₁) == length(levels1) && length(f₂) == length(levels2) @assert TupleTools.isperm((p1..., p2...)) From 4470e2fd64fa7076316f9ecf2d362d1d6d86db3f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 15 Aug 2025 11:59:56 +0200 Subject: [PATCH 17/28] Move independent computations out of loop --- src/fusiontrees/fusiontreeblocks.jl | 6 ++++-- src/fusiontrees/manipulations.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 0022d4652..4df3069e7 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -229,6 +229,7 @@ function foldright(src::FusionTreeBlock) c ∈ cset || continue for μ in 1:Nsymbol(c1, c2, c) fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,)) + frs_coeffs = insertat(fc, 2, f₂) for (fl′, coeff1) in insertat(fc, 2, f₁) N₁ > 1 && !isone(fl′.innerlines[1]) && continue coupled = fl′.coupled @@ -237,7 +238,7 @@ function foldright(src::FusionTreeBlock) inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines)) vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices)) fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) - for (fr, coeff2) in insertat(fc, 2, f₂) + for (fr, coeff2) in frs_coeffs coeff = factor * coeff1 * conj(coeff2) row = indexmap[(fl, fr)] @inbounds U[row, col] = coeff @@ -297,6 +298,7 @@ function foldleft(src::FusionTreeBlock) c ∈ cset || continue for μ in 1:Nsymbol(c1, c2, c) fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,)) + fr_coeffs = insertat(fc, 2, f₂) for (fl′, coeff1) in insertat(fc, 2, f₁) N₁ > 1 && !isone(fl′.innerlines[1]) && continue coupled = fl′.coupled @@ -305,7 +307,7 @@ function foldleft(src::FusionTreeBlock) inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines)) vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices)) fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) - for (fr, coeff2) in insertat(fc, 2, f₂) + for (fr, coeff2) in fr_coeffs coeff = factor * coeff1 * conj(coeff2) row = indexmap[(fr, fl)] @inbounds U[row, col] = conj(coeff) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 92742efc1..c01b5641c 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -368,6 +368,7 @@ function foldright((f₁, f₂)::FusionTreePair{I,N₁,N₂}) where {I,N₁,N₂ c ∈ cset || continue for μ in 1:Nsymbol(c1, c2, c) fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,)) + fr_coeffs = insertat(fc, 2, f₂) for (fl′, coeff1) in insertat(fc, 2, f₁) N₁ > 1 && !isone(fl′.innerlines[1]) && continue coupled = fl′.coupled @@ -376,7 +377,7 @@ function foldright((f₁, f₂)::FusionTreePair{I,N₁,N₂}) where {I,N₁,N₂ inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines)) vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices)) fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) - for (fr, coeff2) in insertat(fc, 2, f₂) + for (fr, coeff2) in fr_coeffs coeff = factor * coeff1 * conj(coeff2) if (@isdefined newtrees) newtrees[(fl, fr)] = get(newtrees, (fl, fr), zero(coeff)) + From 731a832d30cbdc058893e39a904dd67e02aecf72 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 15 Aug 2025 12:12:19 +0200 Subject: [PATCH 18/28] add utility fusiontreetype --- src/fusiontrees/fusiontrees.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 1da733ece..48a0c9802 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -145,6 +145,9 @@ function fusiontreetype(::Type{I}, N::Int) where {I<:Sector} FusionTree{I,N,N - 2,N - 1} end end +function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector} + return Tuple{fusiontreetype(I, N₁),fusiontreetype(I, N₂)} +end # converting to actual array function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I,0}) where {I} From fff2d5a64d62884197ce69b0ddb628024c49e576 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 15 Aug 2025 12:12:32 +0200 Subject: [PATCH 19/28] add multithreaded treetransformer implementation --- src/tensors/treetransformers.jl | 89 ++++++++++++++++++++++++++------- 1 file changed, 71 insertions(+), 18 deletions(-) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index d31706dc4..b1d8d84ef 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -66,32 +66,85 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) I = sectortype(Vsrc) T = sectorscalartype(I) N = numind(Vdst) - data = Vector{_GenericTransformerData{T,N}}() isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces)) - for cod_uncoupled_src in sectors(codomain(Vsrc)), - dom_uncoupled_src in sectors(domain(Vsrc)) - fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src) - trees_src = fusiontrees(fs_src) - isempty(trees_src) && continue + nthreads = get_num_transformer_threads() + if nthreads > 1 + fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}() + for cod_uncoupled_src in sectors(codomain(Vsrc)), + dom_uncoupled_src in sectors(domain(Vsrc)) + + fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src) + trees_src = fusiontrees(fs_src) + if !isempty(trees_src) + push!(fusiontreeblocks, fs_src) + end + end + + data = Vector{_GenericTransformerData{T,N}}(undef, length(fusiontreeblocks)) + counter = Threads.Atomic{Int}(1) + Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks)) + Threads.@spawn begin + while true + local_counter = Threads.atomic_add!(counter, 1) + local_counter > nblocks && break + fs_src = fusiontreeblocks[local_counter] + fs_dst, U = transform(fs_src) + matrix = copy(transpose(U)) # TODO: should we avoid this + + inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), + trees_src) + trees_dst = fusiontrees(fs_dst) + inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), + trees_dst) + + # size is shared between blocks, so repack: + # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...]) + sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, + inds_src) + sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, + inds_dst) + + @debug("Created recoupling block for uncoupled: $uncoupled", + sz = size(matrix), + sparsity = count(!iszero, matrix) / length(matrix)) + + data[local_counter] = (matrix, (sz_dst, newstructs_dst), + (sz_src, newstructs_src)) + end + end + end + else + data = Vector{_GenericTransformerData{T,N}}() - fs_dst, U = transform(fs_src) - matrix = copy(transpose(U)) # TODO: should we avoid this + isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces)) + for cod_uncoupled_src in sectors(codomain(Vsrc)), + dom_uncoupled_src in sectors(domain(Vsrc)) - inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src) - trees_dst = fusiontrees(fs_dst) - inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst) + fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src) + trees_src = fusiontrees(fs_src) + isempty(trees_src) && continue - # size is shared between blocks, so repack: - # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...]) - sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, inds_src) - sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst) + fs_dst, U = transform(fs_src) + matrix = copy(transpose(U)) # TODO: should we avoid this - @debug("Created recoupling block for uncoupled: $uncoupled", - sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix)) + inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src) + trees_dst = fusiontrees(fs_dst) + inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst) - push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))) + # size is shared between blocks, so repack: + # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...]) + sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, + inds_src) + sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, + inds_dst) + + @debug("Created recoupling block for uncoupled: $uncoupled", + sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix)) + + push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))) + end end transformer = GenericTreeTransformer{T,N}(data) From 54b7abca31473cc0114eb7d4a5de1d4869ee00b6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 15 Aug 2025 13:00:24 +0200 Subject: [PATCH 20/28] refactor treeindex_map --- src/fusiontrees/fusiontreeblocks.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 4df3069e7..10eb3d7f1 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -54,12 +54,18 @@ numind(::Type{T}) where {T<:FusionTreeBlock} = numin(T) + numout(T) fusiontrees(block::FusionTreeBlock) = block.trees Base.length(block::FusionTreeBlock) = length(fusiontrees(block)) +function treeindex_map(fs::FusionTreeBlock) + I = sectortype(fs) + return fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(fs))) +end + + # Manipulations # ------------- function transformation_matrix(transform, dst::FusionTreeBlock{I}, src::FusionTreeBlock{I}) where {I} U = zeros(sectorscalartype(I), length(dst), length(src)) - indexmap = Dict(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + indexmap = treeindex_map(dst) for (col, f) in enumerate(fusiontrees(src)) for (f′, c) in transform(f) row = indexmap[f′] @@ -80,7 +86,7 @@ function bendright(src::FusionTreeBlock) @assert N₁ > 0 dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) - indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) for (col, (f₁, f₂)) in enumerate(fusiontrees(src)) @@ -139,7 +145,7 @@ function bendleft(src::FusionTreeBlock) @assert N₁ > 0 dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) - indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) for (col, (f₂, f₁)) in enumerate(fusiontrees(src)) @@ -196,7 +202,7 @@ function foldright(src::FusionTreeBlock) dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) - indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) for (col, (f₁, f₂)) in enumerate(fusiontrees(src)) @@ -265,7 +271,7 @@ function foldleft(src::FusionTreeBlock) @assert N₁ > 0 dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) - indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst))) + indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) for (col, (f₂, f₁)) in enumerate(fusiontrees(src)) From fe3cfd4e09d4589b81946dbf0d60fa1dd7114958 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 16 Aug 2025 18:06:26 +0200 Subject: [PATCH 21/28] Refactor artin_braid to avoid extra dicts --- src/fusiontrees/fusiontreeblocks.jl | 125 +++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 4 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 10eb3d7f1..607f32717 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -59,7 +59,6 @@ function treeindex_map(fs::FusionTreeBlock) return fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(fs))) end - # Manipulations # ------------- function transformation_matrix(transform, dst::FusionTreeBlock{I}, @@ -451,10 +450,128 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) dst = FusionTreeBlock{I}((uncoupled′, ()), (isdual′, ())) - # TODO: do we want to rewrite `artin_braid` to take double trees instead? - U = transformation_matrix(dst, src) do (f₁, f₂) - return ((f₁′, f₂) => c for (f₁′, c) in artin_braid(f₁, i; inv)) + indexmap = treeindex_map(dst) + U = zeros(sectorscalartype(I), length(dst), length(src)) + + for (col, (f, f₂)) in enumerate(fusiontrees(src)) + a, b = uncoupled[i], uncoupled[i + 1] + uncoupled′ = TupleTools.setindex(uncoupled, b, i) + uncoupled′ = TupleTools.setindex(uncoupled′, a, i + 1) + coupled′ = f.coupled + isdual′ = TupleTools.setindex(f.isdual, f.isdual[i], i + 1) + isdual′ = TupleTools.setindex(isdual′, f.isdual[i + 1], i) + inner = f.innerlines + inner_extended = (uncoupled[1], inner..., coupled′) + vertices = f.vertices + oneT = one(sectorscalartype(I)) + + if isone(uncoupled[i]) || isone(uncoupled[i + 1]) + # braiding with trivial sector: simple and always possible + inner′ = inner + vertices′ = vertices + if i > 1 # we also need to alter innerlines and vertices + inner′ = TupleTools.setindex(inner, + inner_extended[isone(a) ? (i + 1) : (i - 1)], + i - 1) + vertices′ = TupleTools.setindex(vertices′, vertices[i], i - 1) + vertices′ = TupleTools.setindex(vertices′, vertices[i - 1], i) + end + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′, vertices′) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = oneT + continue + end + + BraidingStyle(I) isa NoBraiding && + throw(SectorMismatch("Cannot braid sectors $(uncoupled[i]) and $(uncoupled[i + 1])")) + + if i == 1 + c = N > 2 ? inner[1] : coupled′ + if FusionStyle(I) isa MultiplicityFreeFusion + R = oftype(oneT, (inv ? conj(Rsymbol(b, a, c)) : Rsymbol(a, b, c))) + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner, vertices) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = R + else # GenericFusion + μ = vertices[1] + Rmat = inv ? Rsymbol(b, a, c)' : Rsymbol(a, b, c) + local newtrees + for ν in axes(Rmat, 2) + R = oftype(oneT, Rmat[μ, ν]) + iszero(R) && continue + vertices′ = TupleTools.setindex(vertices, ν, 1) + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner, vertices′) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = R + end + end + continue + end + # case i > 1: other naming convention + b = uncoupled[i] + d = uncoupled[i + 1] + a = inner_extended[i - 1] + c = inner_extended[i] + e = inner_extended[i + 1] + if FusionStyle(I) isa UniqueFusion + c′ = first(a ⊗ d) + coeff = oftype(oneT, + if inv + conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) * + Rsymbol(d, a, c′) + else + Rsymbol(c, d, e) * + conj(Fsymbol(d, a, b, e, c′, c) * Rsymbol(a, d, c′)) + end) + inner′ = TupleTools.setindex(inner, c′, i - 1) + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = coeff + elseif FusionStyle(I) isa SimpleFusion + cs = collect(I, intersect(a ⊗ d, e ⊗ conj(b))) + for c′ in cs + coeff = oftype(oneT, + if inv + conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) * + Rsymbol(d, a, c′) + else + Rsymbol(c, d, e) * + conj(Fsymbol(d, a, b, e, c′, c) * Rsymbol(a, d, c′)) + end) + iszero(coeff) && continue + inner′ = TupleTools.setindex(inner, c′, i - 1) + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = coeff + end + else # GenericFusion + cs = collect(I, intersect(a ⊗ d, e ⊗ conj(b))) + for c′ in cs + Rmat1 = inv ? Rsymbol(d, c, e)' : Rsymbol(c, d, e) + Rmat2 = inv ? Rsymbol(d, a, c′)' : Rsymbol(a, d, c′) + Fmat = Fsymbol(d, a, b, e, c′, c) + μ = vertices[i - 1] + ν = vertices[i] + for σ in 1:Nsymbol(a, d, c′) + for λ in 1:Nsymbol(c′, b, e) + coeff = zero(oneT) + for ρ in 1:Nsymbol(d, c, e), κ in 1:Nsymbol(d, a, c′) + coeff += Rmat1[ν, ρ] * conj(Fmat[κ, λ, μ, ρ]) * + conj(Rmat2[σ, κ]) + end + iszero(coeff) && continue + vertices′ = TupleTools.setindex(vertices, σ, i - 1) + vertices′ = TupleTools.setindex(vertices′, λ, i) + inner′ = TupleTools.setindex(inner, c′, i - 1) + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′, vertices′) + row = indexmap[(f′, f₂)] + @inbounds U[row, col] = coeff + end + end + end + end end + return dst, U end From 9247e381bdc0935493c3214f797bfb8a1aebf084 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 16 Aug 2025 18:07:38 +0200 Subject: [PATCH 22/28] type stability improvements --- src/fusiontrees/fusiontreeblocks.jl | 5 ++++- src/fusiontrees/fusiontrees.jl | 4 ++-- src/fusiontrees/manipulations.jl | 6 ++++-- src/tensors/treetransformers.jl | 29 ++++++++++++----------------- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 607f32717..f2a619428 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -607,7 +607,10 @@ end const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂},Index2Tuple} -@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂}, +@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂, + fusiontreetype(I, + N₁, + N₂)}, Matrix{sectorscalartype(I)}} where {I, N₁, N₂} diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 48a0c9802..7fb46d6a7 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -136,7 +136,7 @@ end Base.:(==)(f₁::FusionTree, f₂::FusionTree) = false # Facilitate getting correct fusion tree types -function fusiontreetype(::Type{I}, N::Int) where {I<:Sector} +Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N::Int) where {I<:Sector} if N === 0 FusionTree{I,0,0,0} elseif N === 1 @@ -145,7 +145,7 @@ function fusiontreetype(::Type{I}, N::Int) where {I<:Sector} FusionTree{I,N,N - 2,N - 1} end end -function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector} +Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector} return Tuple{fusiontreetype(I, N₁),fusiontreetype(I, N₂)} end diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index c01b5641c..ed9869760 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -884,7 +884,8 @@ function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I,N} return fusiontreedict(I)(f′ => coeff) elseif FusionStyle(I) isa SimpleFusion local newtrees - for c′ in intersect(a ⊗ d, e ⊗ conj(b)) + cs = collect(I, intersect(a ⊗ d, e ⊗ conj(b))) + for c′ in cs coeff = oftype(oneT, if inv conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) * @@ -905,7 +906,8 @@ function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I,N} return newtrees else # GenericFusion local newtrees - for c′ in intersect(a ⊗ d, e ⊗ conj(b)) + cs = collect(I, intersect(a ⊗ d, e ⊗ conj(b))) + for c′ in cs Rmat1 = inv ? Rsymbol(d, c, e)' : Rsymbol(c, d, e) Rmat2 = inv ? Rsymbol(d, a, c′)' : Rsymbol(a, d, c′) Fmat = Fsymbol(d, a, b, e, c′, c) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index b1d8d84ef..3805a22a1 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -66,9 +66,13 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) I = sectortype(Vsrc) T = sectorscalartype(I) N = numind(Vdst) + N₁ = numout(Vsrc) + N₂ = numin(Vsrc) isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces)) + data = Vector{_GenericTransformerData{T,N}}() + nthreads = get_num_transformer_threads() if nthreads > 1 fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}() @@ -82,7 +86,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) end end - data = Vector{_GenericTransformerData{T,N}}(undef, length(fusiontreeblocks)) + resize!(data, length(fusiontreeblocks)) counter = Threads.Atomic{Int}(1) Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks)) Threads.@spawn begin @@ -106,18 +110,13 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst) - @debug("Created recoupling block for uncoupled: $uncoupled", - sz = size(matrix), - sparsity = count(!iszero, matrix) / length(matrix)) - - data[local_counter] = (matrix, (sz_dst, newstructs_dst), - (sz_src, newstructs_src)) + data1[local_counter] = (matrix, (sz_dst, newstructs_dst), + (sz_src, newstructs_src)) end end end + transformer = GenericTreeTransformer{T,N}(data) else - data = Vector{_GenericTransformerData{T,N}}() - isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces)) for cod_uncoupled_src in sectors(codomain(Vsrc)), dom_uncoupled_src in sectors(domain(Vsrc)) @@ -140,24 +139,20 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst) - @debug("Created recoupling block for uncoupled: $uncoupled", - sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix)) - push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))) end + transformer = GenericTreeTransformer{T,N}(data) end - transformer = GenericTreeTransformer{T,N}(data) - # sort by (approximate) weight to facilitate multi-threading strategies sort!(transformer) Δt = Base.time() - t₀ @debug("TreeTransformer for $Vsrc to $Vdst via $p", - nblocks = length(data), - sz_median = size(data[cld(end, 2)][1], 1), - sz_max = size(data[1][1], 1), + nblocks = length(transformer.data), + sz_median = size(transformer.data[cld(end, 2)][1], 1), + sz_max = size(transformer.data[1][1], 1), Δt) return transformer From ed9e260e1e71b47a8a77fdcf28248b798b207bb1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 09:36:20 +0200 Subject: [PATCH 23/28] fix multithreaded implementation --- src/tensors/treetransformers.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 3805a22a1..1dc1dd832 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -85,10 +85,11 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) push!(fusiontreeblocks, fs_src) end end + nblocks = length(fusiontreeblocks) - resize!(data, length(fusiontreeblocks)) + resize!(data, nblocks) counter = Threads.Atomic{Int}(1) - Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks)) + Threads.@sync for _ in 1:min(nthreads, nblocks) Threads.@spawn begin while true local_counter = Threads.atomic_add!(counter, 1) @@ -97,6 +98,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) fs_dst, U = transform(fs_src) matrix = copy(transpose(U)) # TODO: should we avoid this + trees_src = fusiontrees(fs_src) inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src) trees_dst = fusiontrees(fs_dst) @@ -110,7 +112,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst) - data1[local_counter] = (matrix, (sz_dst, newstructs_dst), + data[local_counter] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)) end end From d672fd0d3f474fde9d66daef7f6e9a7dff92ebfd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 09:37:12 +0200 Subject: [PATCH 24/28] speed up hashing by hashing less things --- src/fusiontrees/fusiontreeblocks.jl | 44 ++++++++++++++++++----------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index f2a619428..2e7763682 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -54,9 +54,22 @@ numind(::Type{T}) where {T<:FusionTreeBlock} = numin(T) + numout(T) fusiontrees(block::FusionTreeBlock) = block.trees Base.length(block::FusionTreeBlock) = length(fusiontrees(block)) +# Within one block, all values of uncoupled and isdual are equal, so avoid hashing these +function treeindex_data((f₁, f₂)) + I = sectortype(f₁) + if FusionStyle(I) isa GenericFusion + return (f₁.coupled, f₁.innerlines..., f₂.innerlines...), + (f₁.vertices..., f₂.vertices...) + elseif FusionStyle(I) isa MultipleFusion + return (f₁.coupled, f₁.innerlines..., f₂.innerlines...) + else # there should be only a single element anyways + return false + end +end function treeindex_map(fs::FusionTreeBlock) I = sectortype(fs) - return fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(fs))) + return fusiontreedict(I)(treeindex_data(f) => ind + for (ind, f) in enumerate(fusiontrees(fs))) end # Manipulations @@ -112,7 +125,7 @@ function bendright(src::FusionTreeBlock) coeff = coeff₀ * Bsymbol(a, b, c) vertices2 = N₂ > 0 ? (f₂.vertices..., 1) : () f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) - row = indexmap[(f₁′, f₂′)] + row = indexmap[treeindex_data((f₁′, f₂′))] @inbounds U[row, col] = coeff else Bmat = Bsymbol(a, b, c) @@ -171,7 +184,7 @@ function bendleft(src::FusionTreeBlock) coeff = coeff₀ * Bsymbol(a, b, c) vertices2 = N₂ > 0 ? (f₂.vertices..., 1) : () f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) - row = indexmap[(f₂′, f₁′)] + row = indexmap[treeindex_data((f₂′, f₁′))] @inbounds U[row, col] = conj(coeff) else Bmat = Bsymbol(a, b, c) @@ -181,7 +194,7 @@ function bendleft(src::FusionTreeBlock) iszero(coeff) && continue vertices2 = N₂ > 0 ? (f₂.vertices..., ν) : () f₂′ = FusionTree(uncoupled2, a, isdual2, inner2, vertices2) - row = indexmap[(f₂′, f₁′)] + row = indexmap[treeindex_data((f₂′, f₁′))] @inbounds U[row, col] = conj(coeff) end end @@ -198,7 +211,7 @@ function foldright(src::FusionTreeBlock) N₁ = numout(src) N₂ = numin(src) @assert N₁ > 0 - dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) indexmap = treeindex_map(dst) @@ -220,7 +233,7 @@ function foldright(src::FusionTreeBlock) c = first(c1 ⊗ c2) fl = FusionTree{I}(Base.tail(f₁.uncoupled), c, Base.tail(f₁.isdual)) fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...)) - row = indexmap[(fl, fr)] + row = indexmap[treeindex_data((fl, fr))] @inbounds U[row, col] = factor else if N₁ == 1 @@ -245,7 +258,7 @@ function foldright(src::FusionTreeBlock) fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) for (fr, coeff2) in frs_coeffs coeff = factor * coeff1 * conj(coeff2) - row = indexmap[(fl, fr)] + row = indexmap[treeindex_data((fl, fr))] @inbounds U[row, col] = coeff end end @@ -289,7 +302,7 @@ function foldleft(src::FusionTreeBlock) c = first(c1 ⊗ c2) fl = FusionTree{I}(Base.tail(f₁.uncoupled), c, Base.tail(f₁.isdual)) fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...)) - row = indexmap[(fr, fl)] + row = indexmap[treeindex_data((fr, fl))] @inbounds U[row, col] = conj(factor) else if N₁ == 1 @@ -314,7 +327,7 @@ function foldleft(src::FusionTreeBlock) fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices) for (fr, coeff2) in fr_coeffs coeff = factor * coeff1 * conj(coeff2) - row = indexmap[(fr, fl)] + row = indexmap[treeindex_data((fr, fl))] @inbounds U[row, col] = conj(coeff) end end @@ -477,7 +490,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N vertices′ = TupleTools.setindex(vertices′, vertices[i - 1], i) end f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′, vertices′) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = oneT continue end @@ -490,18 +503,17 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N if FusionStyle(I) isa MultiplicityFreeFusion R = oftype(oneT, (inv ? conj(Rsymbol(b, a, c)) : Rsymbol(a, b, c))) f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner, vertices) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = R else # GenericFusion μ = vertices[1] Rmat = inv ? Rsymbol(b, a, c)' : Rsymbol(a, b, c) - local newtrees for ν in axes(Rmat, 2) R = oftype(oneT, Rmat[μ, ν]) iszero(R) && continue vertices′ = TupleTools.setindex(vertices, ν, 1) f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner, vertices′) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = R end end @@ -525,7 +537,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N end) inner′ = TupleTools.setindex(inner, c′, i - 1) f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = coeff elseif FusionStyle(I) isa SimpleFusion cs = collect(I, intersect(a ⊗ d, e ⊗ conj(b))) @@ -541,7 +553,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N iszero(coeff) && continue inner′ = TupleTools.setindex(inner, c′, i - 1) f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = coeff end else # GenericFusion @@ -564,7 +576,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N vertices′ = TupleTools.setindex(vertices′, λ, i) inner′ = TupleTools.setindex(inner, c′, i - 1) f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′, vertices′) - row = indexmap[(f′, f₂)] + row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = coeff end end From f27547e3cb1a15d1970269d8c5a79b72cf784130 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 09:37:39 +0200 Subject: [PATCH 25/28] Slight refactor of artin_braid --- src/fusiontrees/fusiontreeblocks.jl | 40 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 2e7763682..6f0b18a74 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -456,30 +456,26 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) uncoupled = src.uncoupled[1] - uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i) - uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1) + a, b = uncoupled[i], uncoupled[i + 1] + uncoupled′ = TupleTools.setindex(uncoupled, b, i) + uncoupled′ = TupleTools.setindex(uncoupled′, a, i + 1) + coupled′ = rightone(src.uncoupled[1][N]) + isdual = src.isdual[1] isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) dst = FusionTreeBlock{I}((uncoupled′, ()), (isdual′, ())) + oneT = one(sectorscalartype(I)) + indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) - for (col, (f, f₂)) in enumerate(fusiontrees(src)) - a, b = uncoupled[i], uncoupled[i + 1] - uncoupled′ = TupleTools.setindex(uncoupled, b, i) - uncoupled′ = TupleTools.setindex(uncoupled′, a, i + 1) - coupled′ = f.coupled - isdual′ = TupleTools.setindex(f.isdual, f.isdual[i], i + 1) - isdual′ = TupleTools.setindex(isdual′, f.isdual[i + 1], i) - inner = f.innerlines - inner_extended = (uncoupled[1], inner..., coupled′) - vertices = f.vertices - oneT = one(sectorscalartype(I)) - - if isone(uncoupled[i]) || isone(uncoupled[i + 1]) - # braiding with trivial sector: simple and always possible + if isone(a) || isone(b) # braiding with trivial sector: simple and always possible + for (col, (f, f₂)) in enumerate(fusiontrees(src)) + inner = f.innerlines + inner_extended = (uncoupled[1], inner..., coupled′) + vertices = f.vertices inner′ = inner vertices′ = vertices if i > 1 # we also need to alter innerlines and vertices @@ -492,11 +488,17 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner′, vertices′) row = indexmap[treeindex_data((f′, f₂))] @inbounds U[row, col] = oneT - continue end + return dst, U + end - BraidingStyle(I) isa NoBraiding && - throw(SectorMismatch("Cannot braid sectors $(uncoupled[i]) and $(uncoupled[i + 1])")) + BraidingStyle(I) isa NoBraiding && + throw(SectorMismatch(lazy"Cannot braid sectors $a and $b")) + + for (col, (f, f₂)) in enumerate(fusiontrees(src)) + inner = f.innerlines + inner_extended = (uncoupled[1], inner..., coupled′) + vertices = f.vertices if i == 1 c = N > 2 ? inner[1] : coupled′ From aae6602ca5446bd89e910d4dc3b4123f9143ac8b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 09:54:03 +0200 Subject: [PATCH 26/28] reduce allocations with sizehints --- src/fusiontrees/fusiontreeblocks.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/fusiontrees/fusiontreeblocks.jl b/src/fusiontrees/fusiontreeblocks.jl index 6f0b18a74..ecd06739b 100644 --- a/src/fusiontrees/fusiontreeblocks.jl +++ b/src/fusiontrees/fusiontreeblocks.jl @@ -3,11 +3,11 @@ struct FusionTreeBlock{I,N₁,N₂,F<:FusionTreePair{I,N₁,N₂}} end function FusionTreeBlock{I}(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}, - isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}) where {I<:Sector, - N₁,N₂} - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - trees = Vector{Tuple{F₁,F₂}}(undef, 0) + isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}; + sizehint::Int=0) where {I<:Sector,N₁,N₂} + F = fusiontreetype(I, N₁, N₂) + trees = Vector{F}(undef, 0) + sizehint > 0 && sizehint!(trees, sizehint) if N₁ == N₂ == 0 return FusionTreeBlock(trees) @@ -97,7 +97,7 @@ function bendright(src::FusionTreeBlock) N₂ = numin(src) @assert N₁ > 0 - dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst; sizehint=length(src)) indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) @@ -156,7 +156,7 @@ function bendleft(src::FusionTreeBlock) N₂ = numout(src) @assert N₁ > 0 - dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst; sizehint=length(src)) indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) @@ -211,9 +211,8 @@ function foldright(src::FusionTreeBlock) N₁ = numout(src) N₂ = numin(src) @assert N₁ > 0 - dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) - dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst; sizehint=length(src)) indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) @@ -282,7 +281,7 @@ function foldleft(src::FusionTreeBlock) N₂ = numout(src) @assert N₁ > 0 - dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst) + dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst; sizehint=length(src)) indexmap = treeindex_map(dst) U = zeros(sectorscalartype(I), length(dst), length(src)) @@ -464,7 +463,7 @@ function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N isdual = src.isdual[1] isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) - dst = FusionTreeBlock{I}((uncoupled′, ()), (isdual′, ())) + dst = FusionTreeBlock{I}((uncoupled′, ()), (isdual′, ()); sizehint=length(src)) oneT = one(sectorscalartype(I)) @@ -596,7 +595,7 @@ function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int}, if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding uncoupled′ = TupleTools._permute(src.uncoupled[1], p) isdual′ = TupleTools._permute(src.isdual[1], p) - dst = FusionTreeBlock{I}(uncoupled′, isdual′) + dst = FusionTreeBlock{I}(uncoupled′, isdual′; sizehint=length(src)) U = transformation_matrix(dst, src) do (f₁, f₂) return ((f₁′, f₂) => c for (f₁′, c) in braid(f₁, p, levels)) end From edde8f9ee5b310012938a3a2b49ea5ff69875cdf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 10:18:22 +0200 Subject: [PATCH 27/28] separate treemanipulation threads --- src/TensorKit.jl | 13 +++++++++++++ src/tensors/treetransformers.jl | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6f7467e37..1bcd39135 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -198,6 +198,19 @@ function set_num_transformer_threads(n::Int) return TRANSFORMER_THREADS[] = n end +const TREEMANIPULATION_THREADS = Ref(1) + +get_num_manipulation_threads() = TREEMANIPULATION_THREADS[] + +function set_num_manipulation_threads(n::Int) + N = Base.Threads.nthreads() + if n > N + n = N + Strided._set_num_threads_warn(n) + end + return TREEMANIPULATION_THREADS[] = n +end + # Definitions and methods for tensors #------------------------------------- # general definitions diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 1dc1dd832..a6031c635 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -73,7 +73,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) data = Vector{_GenericTransformerData{T,N}}() - nthreads = get_num_transformer_threads() + nthreads = get_num_manipulation_threads() if nthreads > 1 fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}() for cod_uncoupled_src in sectors(codomain(Vsrc)), From 4764f72e4099193e7fd6c464c45a3c653f41fb15 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 17 Aug 2025 10:39:03 +0200 Subject: [PATCH 28/28] formatter --- src/fusiontrees/fusiontrees.jl | 3 ++- src/tensors/treetransformers.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 7fb46d6a7..64daff4f7 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -145,7 +145,8 @@ Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N::Int) where FusionTree{I,N,N - 2,N - 1} end end -Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector} +Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N₁::Int, + N₂::Int) where {I<:Sector} return Tuple{fusiontreetype(I, N₁),fusiontreetype(I, N₂)} end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index a6031c635..8c9601ad3 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -113,7 +113,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) inds_dst) data[local_counter] = (matrix, (sz_dst, newstructs_dst), - (sz_src, newstructs_src)) + (sz_src, newstructs_src)) end end end