From 2afcdfc045ac91d39a4d9d244a3f721fbfdcf0ae Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 3 Apr 2026 11:25:43 +0200 Subject: [PATCH 01/10] Allow BraidingTensor to have a custom storage type --- ext/TensorKitAdaptExt.jl | 8 +- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 16 +- ext/TensorKitCUDAExt/cutensormap.jl | 4 + src/planar/preprocessors.jl | 39 ++- src/tensors/braidingtensor.jl | 151 ++++++------ test/cuda/factorizations.jl | 2 + test/cuda/planar.jl | 300 +++++++++++++++++++++++ test/cuda/tensors.jl | 35 ++- test/setup.jl | 6 +- 9 files changed, 459 insertions(+), 102 deletions(-) create mode 100644 test/cuda/planar.jl diff --git a/ext/TensorKitAdaptExt.jl b/ext/TensorKitAdaptExt.jl index 30c223820..4d2693b7b 100644 --- a/ext/TensorKitAdaptExt.jl +++ b/ext/TensorKitAdaptExt.jl @@ -15,8 +15,12 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap) data′ = adapt(to, x.data) return DiagonalTensorMap(data′, x.domain) end -function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}} - return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint) +function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A} + A′ = TensorKit.similarstoragetype(A, T) + return BraidingTensor{T, S, A′}(space(x), x.adjoint) +end +function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A} + return BraidingTensor{T′, S, TA}(space(x), x.adjoint) end end diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index f5efb98bb..83e34a0a7 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -3,14 +3,17 @@ module TensorKitCUDAExt using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra using CUDA: @allowscalar using cuTENSOR: cuTENSOR +using Strided: StridedViews import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn! +using CUDA.KernelAbstractions +using CUDA.KernelAbstractions: @kernel, @index using TensorKit using TensorKit.Factorizations using TensorKit.Strided using TensorKit.Factorizations: AbstractAlgorithm using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check -import TensorKit: randisometry, rand, randn +import TensorKit: randisometry, rand, randn, fill_braidingsubblock! using TensorKit: MatrixAlgebraKit @@ -19,4 +22,15 @@ using Random include("cutensormap.jl") include("truncation.jl") +function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}} + @kernel function fill_subblock_kernel!(subblock, val) + idx = @index(Global, Cartesian) + idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val) + @inbounds subblock[idx] = idx_val + end + kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data)) + kernel(data, val; ndrange = size(data)) + return data +end + end diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index f065c2ec1..8894164a9 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -168,3 +168,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth) return tf end end + +function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS} + return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...) +end diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index d30406750..3098fc4cf 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -83,6 +83,23 @@ _add_adjoint(ex) = Expr(TO.prime, ex) # spaces from the rest of the expression. Construct the explicit BraidingTensor objects and # insert them in the expression. function _construct_braidingtensors(ex) + function filter_f(expr) + if TO.istensor(expr) + return _remove_adjoint(TO.decomposetensor(expr)[1]) != :τ + elseif TO.istensorexpr(expr) + return any(filter_f, expr.args) + else + return false + end + end + function extract_tensors(tensor_ex) + if TO.istensor(tensor_ex) + return [TO.decomposetensor(tensor_ex)[1]] + elseif TO.istensorexpr(tensor_ex) + return collect(Iterators.flatmap(extract_tensors, filter(filter_f, tensor_ex.args))) + end + end + # get storagetype ex isa Expr || return ex if ex.head == :macrocall && ex.args[1] == Symbol("@notensor") return ex @@ -104,7 +121,9 @@ function _construct_braidingtensors(ex) ) end end - newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap) + # if this is a definition, the lhs tensor is NOT yet defined + no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, rhs.args)); init = Symbol[]) + newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap, no_τ_ex) success || throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) pre = Expr( @@ -115,7 +134,8 @@ function _construct_braidingtensors(ex) elseif TO.istensorexpr(ex) preargs = Vector{Any}() indexmap = Dict{Any, Any}() - newex, success = _construct_braidingtensors!(ex, preargs, indexmap) + no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, ex.args)); init = Symbol[]) + newex, success = _construct_braidingtensors!(ex, preargs, indexmap, no_τ_ex) success || throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) pre = Expr( @@ -128,7 +148,7 @@ function _construct_braidingtensors(ex) end end -function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression +function _construct_braidingtensors!(ex, preargs, indexmap, non_braiding) # ex is guaranteed to be a single tensor expression if TO.isscalarexpr(ex) # ex could be tensorscalar call with more braiding tensors return _construct_braidingtensors(ex), true @@ -163,7 +183,8 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t end if foundV1 && foundV2 s = gensym(:τ) - constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2) + storageex = Expr(:call, GlobalRef(TensorKit, :promote_storagetype), non_braiding...) + constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), storageex, V1, V2) push!(preargs, Expr(:(=), s, constructex)) obj = _is_adjoint(obj) ? _add_adjoint(s) : s success = true @@ -196,7 +217,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t newargs = Vector{Any}(undef, length(args)) success = true for i in 1:length(ex.args) - newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap, non_braiding) success = success && successa end newex = Expr(ex.head, newargs...) @@ -212,7 +233,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t for i in 2:length(ex.args) successes[i] && continue newargs[i], successa = _construct_braidingtensors!( - args[i], preargs, indexmap + args[i], preargs, indexmap, non_braiding ) successes[i] = successa end @@ -232,7 +253,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t indices = [TO.getindices(arg) for arg in args] for i in 2:length(ex.args) indexmapa = copy(indexmap) - newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa, non_braiding) for l in indices[i] if !haskey(indexmap, l) && haskey(indexmapa, l) indexmap[l] = indexmapa[l] @@ -243,10 +264,10 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t newex = Expr(ex.head, newargs...) return newex, success elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3 - newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap) + newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap, non_braiding) return Expr(:call, :/, newarg, ex.args[3]), success elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3 - newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap) + newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap, non_braiding) return Expr(:call, :\, ex.args[2], newarg), success else error("unexpected expression $ex") diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 3ff8a9abf..e1c6694ec 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -2,73 +2,81 @@ # special (2,2) tensor that implements a standard braiding operation #====================================================================# """ - struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2} + struct BraidingTensor{T, S <: IndexSpace, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2} BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace} + BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A} Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that braids the first input over the second input; its inverse can be obtained as the adjoint. It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and -`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. +`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA` +controls the array type of the braiding tensor used when indexing +and multiplying with other tensors. """ -struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2} +struct BraidingTensor{T, S, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2} V1::S V2::S adjoint::Bool - function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace} - for a in sectors(V1) - for b in sectors(V2) - for c in (a ⊗ b) - Nsymbol(a, b, c) == Nsymbol(b, a, c) || - throw(ArgumentError("Cannot define a braiding between $a and $b")) - end - end + function BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}} + for a in sectors(V1), b in sectors(V2), c in (a ⊗ b) + Nsymbol(a, b, c) == Nsymbol(b, a, c) || + throw(ArgumentError("Cannot define a braiding between $a and $b")) end - return new{T, S}(V1, V2, adjoint) + return new{T, S, A}(V1, V2, adjoint) # partial construction: only construct rowr and colr when needed end end function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace} - return BraidingTensor{T, S}(V1, V2, adjoint) + TA = similarstoragetype(T) + return BraidingTensor{T, S, TA}(V1, V2, adjoint) end -function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T} - return BraidingTensor{T}(promote(V1, V2)..., adjoint) +function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} + T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64 + return BraidingTensor{T}(V1, V2, adjoint) end function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) return BraidingTensor(promote(V1, V2)..., adjoint) end -function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} - T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64 - return BraidingTensor{T, S}(V1, V2, adjoint) -end function BraidingTensor(V::HomSpace, adjoint::Bool = false) domain(V) == reverse(codomain(V)) || throw(SpaceMismatch("Cannot define a braiding on $V")) return BraidingTensor(V[2], V[1], adjoint) end +function BraidingTensor{T, S, A}(V::HomSpace, adjoint::Bool = false) where {T, S, A} + domain(V) == reverse(codomain(V)) || + throw(SpaceMismatch("Cannot define a braiding on $V")) + return BraidingTensor{T, S, A}(V[2], V[1], adjoint) +end function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T} domain(V) == reverse(codomain(V)) || throw(SpaceMismatch("Cannot define a braiding on $V")) return BraidingTensor{T}(V[2], V[1], adjoint) end -function Base.adjoint(b::BraidingTensor{T, S}) where {T, S} - return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint) +# these are here to make the preprocessing for `@planar` expressions less painful +function BraidingTensor(TorA::Type, V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} + T = eltype(TorA) + TA = similarstoragetype(TorA) + return BraidingTensor{T, S, TA}(V1, V2, adjoint) +end +function BraidingTensor(TorA::Type, V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) + return BraidingTensor(TorA, promote(V1, V2)..., adjoint) +end +function BraidingTensor(TorA::Type, V::HomSpace, adjoint::Bool = false) + domain(V) == reverse(codomain(V)) || + throw(SpaceMismatch("Cannot define a braiding on $V")) + T = eltype(TorA) + S = spacetype(V[2]) + TA = storagetype(TorA) + return BraidingTensor{T, S, TA}(V[2], V[1], adjoint) +end +function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} + return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint) end +storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2 -# specializations to ignore the storagetype of BraidingTensor -promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B) -promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A) -promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A) - -promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} = - similarstoragetype(B, T) -promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} = - similarstoragetype(A, T) -promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} = - similarstoragetype(A, T) - function Base.getindex(b::BraidingTensor) sectortype(b) === Trivial || throw(SectorMismatch()) (V1, V2) = domain(b) @@ -99,6 +107,13 @@ function _braiding_factor(f₁, f₂, inv::Bool = false) return r end +# generates scalar indexing errors on GPU +function fill_braidingsubblock!(data, val) + f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val + return data .= f.(CartesianIndices(data)) +end + + @inline function subblock( b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}} ) where {I <: Sector} @@ -115,15 +130,10 @@ end d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...) n1 = d[1] * d[2] n2 = d[3] * d[4] - data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d) - fill!(data, zero(eltype(b))) - + data_parent = storagetype(b)(undef, prod(d)) + data = sreshape(StridedView(data_parent), d) r = _braiding_factor(f₁, f₂, b.adjoint) - if !isnothing(r) - @inbounds for i in axes(data, 1), j in axes(data, 2) - data[i, j, j, i] = r - end - end + isnothing(r) ? zerovector!(data) : fill_braidingsubblock!(data, r) return data end @@ -134,49 +144,52 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b) Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b) Base.complex(b::BraidingTensor{<:Complex}) = b -function Base.complex(b::BraidingTensor) - return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint) +function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A} + Tc = complex(T) + Ac = similarstoragetype(A, Tc) + return BraidingTensor{Tc, S, Ac}(space(b), b.adjoint) end -function block(b::BraidingTensor, s::Sector) - I = sectortype(b) - I == typeof(s) || throw(SectorMismatch()) - - # TODO: probably always square? - m = blockdim(codomain(b), s) - n = blockdim(domain(b), s) - data = Matrix{eltype(b)}(undef, (m, n)) - - length(data) == 0 && return data # s ∉ blocksectors(b) - - data = fill!(data, zero(eltype(b))) - +# Trivial +function fill_braidingblock!(data, b::BraidingTensor, s::Trivial) V1, V2 = codomain(b) - if sectortype(b) === Trivial - d1, d2 = dim(V1), dim(V2) - subblock = sreshape(StridedView(data), (d1, d2, d2, d1)) - @inbounds for i in axes(subblock, 1), j in axes(subblock, 2) - subblock[i, j, j, i] = one(eltype(b)) - end - return data - end + d1, d2 = dim(V1), dim(V2) + subblock = sreshape(StridedView(data), (d1, d2, d2, d1)) + fill_braidingsubblock!(subblock, one(eltype(b))) + return data +end +# Nontrivial +function fill_braidingblock!(data, b::BraidingTensor, s::Sector) base_offset = first(blockstructure(b)[s][2]) - 1 for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b))) (f₁.coupled == f₂.coupled == s) || continue r = _braiding_factor(f₁, f₂, b.adjoint) - isnothing(r) && continue # change offset to account for single block subblock = StridedView(data, sz, str, off - base_offset) - @inbounds for i in axes(subblock, 1), j in axes(subblock, 2) - subblock[i, j, j, i] = r - end + # without the zero-value, the non-trivial block is not set + # correctly in the GPU case + isnothing(r) ? zerovector!(subblock) : fill_braidingsubblock!(subblock, r) end - return data end +function block(b::BraidingTensor, s::Sector) + I = sectortype(b) + I == typeof(s) || throw(SectorMismatch()) + + # TODO: probably always square? + m = blockdim(codomain(b), s) + n = blockdim(domain(b), s) + + data = reshape(storagetype(b)(undef, m * n), (m, n)) + + m * n == 0 && return data # s ∉ blocksectors(b) + + return fill_braidingblock!(data, b, s) +end + # Index manipulations # ------------------- has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false diff --git a/test/cuda/factorizations.jl b/test/cuda/factorizations.jl index 62e23c9df..63848767f 100644 --- a/test/cuda/factorizations.jl +++ b/test/cuda/factorizations.jl @@ -7,6 +7,8 @@ const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt) @assert !isnothing(CUDAExt) "Failed to load TensorKit - CUDA extension" const CuTensorMap = getglobal(CUDAExt, :CuTensorMap) +using CUDA.CUBLAS + spacelist = factorization_spacelist(fast_tests) eltypes = (Float32, ComplexF64) diff --git a/test/cuda/planar.jl b/test/cuda/planar.jl new file mode 100644 index 000000000..454515c5d --- /dev/null +++ b/test/cuda/planar.jl @@ -0,0 +1,300 @@ +using Test, TestExtras +using Adapt +using TensorOperations, CUDA, cuTENSOR +using TensorKit +using TensorKit: type_repr +using TensorKit: PlanarTrivial, ℙ +using TensorKit: planaradd!, planartrace!, planarcontract! + +const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt) +@assert !isnothing(CUDAExt) +const CuTensorMap = getglobal(CUDAExt, :CuTensorMap) + +spacelist = default_spacelist(fast_tests) + +for V in spacelist + I = sectortype(first(V)) + Istr = type_repr(I) + BraidingStyle(I) isa NoBraiding && continue + @timedtestset "Braiding tensor + CUDA with symmetry: $Istr" verbose = true begin + W = V[1] ⊗ V[2] ← V[2] ⊗ V[1] + T = isreal(sectortype(W)) ? Float64 : ComplexF64 + t1 = @constinferred BraidingTensor{T, spacetype(V[2]), CuVector{T, CUDA.DeviceMemory}}(W) + @test space(t1) == W + @test codomain(t1) == codomain(W) + @test domain(t1) == domain(W) + @test scalartype(t1) == (isreal(sectortype(W)) ? Float64 : ComplexF64) + @test storagetype(t1) == CuVector{scalartype(t1), CUDA.DeviceMemory} + t2 = @constinferred BraidingTensor{ComplexF64, spacetype(V[2]), CuVector{ComplexF64, CUDA.DeviceMemory}}(W) + @test scalartype(t2) == ComplexF64 + @test storagetype(t2) == CuVector{ComplexF64, CUDA.DeviceMemory} + t3 = @testinferred adapt(storagetype(t2), t1) + @test storagetype(t3) == storagetype(t2) + # allowscalar needed for the StridedView comparison + @test t3 ≈ t2 + + W2 = reverse(codomain(W)) ← domain(W) + @test_throws SpaceMismatch BraidingTensor(W2) + + @test adjoint(t1) isa BraidingTensor + @test complex(t1) isa BraidingTensor + @test scalartype(complex(t1)) <: Complex + + t3 = @inferred TensorMap(t2) + @test storagetype(t3) == CuVector{ComplexF64, CUDA.DeviceMemory} + t4 = braid(adapt(CuArray, id(scalartype(t2), domain(t2))), ((2, 1), (3, 4)), (1, 2, 3, 4)) + @test t1 ≈ t4 + for (c, b) in blocks(t1) + @test block(t1, c) ≈ b ≈ block(t3, c) + end + + CUDA.@allowscalar begin + for (f1, f2) in fusiontrees(t1) + @test t1[f1, f2] ≈ t3[f1, f2] + end + end + + t5 = @inferred TensorMap(t2') + @test storagetype(t5) == CuVector{ComplexF64, CUDA.DeviceMemory} + t6 = braid(adapt(CuArray, id(scalartype(t2), domain(t2'))), ((2, 1), (3, 4)), (4, 3, 2, 1)) + CUDA.@allowscalar begin + @test t5 ≈ t6 + for (c, b) in blocks(t1') + @test block(t1', c) ≈ b ≈ block(t5, c) + end + for (f1, f2) in fusiontrees(t1') + # needed here for broadcasting the - in isapprox + @test t1'[f1, f2] ≈ t5[f1, f2] + end + end + end +end + +@testset "planar methods" verbose = true begin + @testset "planaradd" begin + A = CUDA.randn(ℂ^2 ⊗ ℂ^3 ← ℂ^6 ⊗ ℂ^5 ⊗ ℂ^4) + C = CUDA.randn((ℂ^5)' ⊗ (ℂ^6)' ← ℂ^4 ⊗ (ℂ^3)' ⊗ (ℂ^2)') + A′ = force_planar(A) + C′ = force_planar(C) + p = ((4, 3), (5, 2, 1)) + + @test force_planar(tensoradd!(C, A, p, false, true, true)) ≈ + planaradd!(C′, A′, p, true, true) + end + + @testset "planartrace" begin + A = CUDA.randn(ℂ^2 ⊗ ℂ^3 ← ℂ^2 ⊗ ℂ^5 ⊗ ℂ^4) + C = CUDA.randn((ℂ^5)' ⊗ ℂ^3 ← ℂ^4) + A′ = force_planar(A) + C′ = force_planar(C) + p = ((4, 2), (5,)) + q = ((1,), (3,)) + + @test force_planar(tensortrace!(C, A, p, q, false, true, true)) ≈ + planartrace!(C′, A′, p, q, true, true) + end + + @testset "planarcontract" begin + A = CUDA.randn(ℂ^2 ⊗ ℂ^3 ← ℂ^2 ⊗ ℂ^5 ⊗ ℂ^4) + B = CUDA.randn(ℂ^2 ⊗ ℂ^4 ← ℂ^4 ⊗ ℂ^3) + C = CUDA.randn((ℂ^5)' ⊗ (ℂ^2)' ⊗ ℂ^2 ← (ℂ^2)' ⊗ ℂ^4) + + A′ = force_planar(A) + B′ = force_planar(B) + C′ = force_planar(C) + + pA = ((1, 3, 4), (5, 2)) + pB = ((2, 4), (1, 3)) + pAB = ((3, 2, 1), (4, 5)) + + @test force_planar(tensorcontract!(C, A, pA, false, B, pB, false, pAB, true, true)) ≈ + planarcontract!(C′, A′, pA, B′, pB, pAB, true, true) + end +end + +@testset "@planar" verbose = true begin + T = ComplexF64 + + @testset "contractcheck" begin + V = ℂ^2 + A = CUDA.rand(T, V ⊗ V ← V) + B = CUDA.rand(T, V ⊗ V ← V') + @tensor C1[i j; k l] := A[i j; m] * B[k l; m] + @tensor contractcheck = true C2[i j; k l] := A[i j; m] * B[k l; m] + @test C1 ≈ C2 + B2 = CUDA.rand(T, V ⊗ V ← V) # wrong duality for third space + @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin + @tensor contractcheck = true C3[i j; k l] := A[i j; m] * B2[k l; m] + end + + #= # TODO NEEDS UPDATES TO planar/preprocessors + A = CUDA.rand(T, V ← V ⊗ V) + B = CUDA.rand(T, V ⊗ V ← V) + @planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] + @planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] + @test C1 ≈ C2 + @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin + @planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] + end=# + end + + @testset "MPS networks" begin + P = ℂ^2 + Vmps = ℂ^12 + Vmpo = ℂ^4 + + # ∂AC + # ------- + x = CUDA.randn(T, Vmps ⊗ P ← Vmps) + O = CUDA.randn(T, Vmpo ⊗ P ← P ⊗ Vmpo) + GL = CUDA.randn(T, Vmps ⊗ Vmpo' ← Vmps) + GR = CUDA.randn(T, Vmps ⊗ Vmpo ← Vmps) + + x′ = force_planar(x) + O′ = force_planar(O) + GL′ = force_planar(GL) + GR′ = force_planar(GR) + + for alloc in + (TensorOperations.DefaultAllocator(),) + @tensor allocator = alloc y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * + O[2 -2; 3 5] * GR[4 5; -3] + @planar allocator = alloc y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * + O′[2 -2; 3 5] * GR′[4 5; -3] + @test force_planar(y) ≈ y′ + end + + # ∂AC2 + # ------- + x2 = CUDA.randn(T, Vmps ⊗ P ← Vmps ⊗ P') + x2′ = force_planar(x2) + @tensor contractcheck = true y2[-1 -2; -3 -4] := GL[-1 7; 6] * x2[6 5; 1 3] * + O[7 -2; 5 4] * O[4 -4; 3 2] * + GR[1 2; -3] + @planar y2′[-1 -2; -3 -4] := GL′[-1 7; 6] * x2′[6 5; 1 3] * O′[7 -2; 5 4] * + O′[4 -4; 3 2] * GR′[1 2; -3] + @test force_planar(y2) ≈ y2′ + + # transfer matrix + # ---------------- + v = CUDA.randn(T, Vmps ← Vmps) + v′ = force_planar(v) + @tensor ρ[-1; -2] := x[-1 2; 1] * conj(x[-2 2; 3]) * v[1; 3] + @planar ρ′[-1; -2] := x′[-1 2; 1] * conj(x′[-2 2; 3]) * v′[1; 3] + @test force_planar(ρ) ≈ ρ′ + + @tensor ρ2[-1 -2; -3] := GL[1 -2; 3] * x[3 2; -3] * conj(x[1 2; -1]) + @plansor ρ3[-1 -2; -3] := GL[1 2; 4] * x[4 5; -3] * τ[2 3; 5 -2] * conj(x[1 3; -1]) + #@planar ρ2′[-1 -2; -3] := GL′[1 2; 4] * x′[4 5; -3] * τ[2 3; 5 -2] * + # conj(x′[1 3; -1]) + #@test force_planar(ρ2) ≈ ρ2′ + @test ρ2 ≈ ρ3 + + # Periodic boundary conditions + # ---------------------------- + f1 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f2 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f1′ = force_planar(f1) + f2′ = force_planar(f2) + @tensor O_periodic1[-1 -2; -3 -4] := O[1 -2; -3 2] * f1[-1; 1 3 4] * + conj(f2[-4; 2 3 4]) + @plansor O_periodic2[-1 -2; -3 -4] := O[1 2; -3 6] * f1[-1; 1 3 5] * + conj(f2[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + #=@planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] * + conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2]=# + @test O_periodic1 ≈ O_periodic2 + #@test force_planar(O_periodic1) ≈ O_periodic′ + end + + @testset "MERA networks" begin + Vmera = ℂ^2 + + u = CUDA.randn(T, Vmera ⊗ Vmera ← Vmera ⊗ Vmera) + w = CUDA.randn(T, Vmera ⊗ Vmera ← Vmera) + ρ = CUDA.randn(T, Vmera ⊗ Vmera ⊗ Vmera ← Vmera ⊗ Vmera ⊗ Vmera) + h = CUDA.randn(T, Vmera ⊗ Vmera ⊗ Vmera ← Vmera ⊗ Vmera ⊗ Vmera) + + u′ = force_planar(u) + w′ = force_planar(w) + ρ′ = force_planar(ρ) + h′ = force_planar(h) + + for alloc in + (TensorOperations.DefaultAllocator(),) + @tensor allocator = alloc begin + C = ( + ( + ( + ( + ( + ((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) * + (u[8 5; 15 6] * w[6 7; 19]) + ) * + (conj(u[8 9; 17 10]) * conj(w[10 11; 22])) + ) * + ((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23]) + ) * + w[16 15; 18] + ) * conj(w[16 17; 21]) + ) + end + @planar allocator = alloc begin + C′ = ( + ( + ( + ( + ( + ((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) * + (u′[8 5; 15 6] * w′[6 7; 19]) + ) * + (conj(u′[8 9; 17 10]) * conj(w′[10 11; 22])) + ) * + ((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23]) + ) * + w′[16 15; 18] + ) * conj(w′[16 17; 21]) + ) + end + @test C ≈ C′ + end + end + + @testset "Issue 93" begin + T = Float64 + V1 = ℂ^2 + V2 = ℂ^3 + t1 = CUDA.randn(T, V1 ← V2) + t2 = CUDA.randn(T, V2 ← V1) + + tr1 = @planar opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @planar opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + + tr1 = @plansor opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @plansor opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @plansor opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @plansor opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @plansor opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @plansor opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @plansor opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + end + @testset "Issue 262" begin + V = ℂ^2 + A = CUDA.randn(T, V ← V) + B = CUDA.randn(T, V ← V') + C = CUDA.randn(T, V' ← V) + @planar D1[i; j] := A[i; j] + B[i; k] * C[k; j] + @planar D2[i; j] := B[i; k] * C[k; j] + A[i; j] + @test D1 ≈ D2 + end +end diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index c88e98b45..8314d8466 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -268,21 +268,18 @@ for V in spacelist for p in permutations(1:5) p1 = ntuple(n -> p[n], k) p2 = ntuple(n -> p[k + n], 5 - k) - CUDA.@allowscalar begin - t2 = @constinferred permute(t, (p1, p2)) - t2 = permute(t, (p1, p2)) - @test norm(t2) ≈ norm(t) - t2′ = permute(t′, (p1, p2)) - @test dot(t2′, t2) ≈ dot(t′, t) ≈ dot(transpose(t2′), transpose(t2)) - end - end - CUDA.@allowscalar begin - t3 = @constinferred repartition(t, $k) - @test norm(t3) ≈ norm(t) - t3′ = @constinferred repartition!(similar(t3), t′) - @test norm(t3′) ≈ norm(t′) - @test dot(t′, t) ≈ dot(t3′, t3) + t2 = @constinferred permute(t, (p1, p2)) + t2 = permute(t, (p1, p2)) + @test norm(t2) ≈ norm(t) + t2′ = permute(t′, (p1, p2)) + @test dot(t2′, t2) ≈ dot(t′, t) ≈ dot(transpose(t2′), transpose(t2)) end + t3 = @constinferred repartition(t, $k) + t3 = repartition(t, k) + @test norm(t3) ≈ norm(t) + t3′ = @constinferred repartition!(similar(t3), t′) + @test norm(t3′) ≈ norm(t′) + @test dot(t′, t) ≈ dot(t3′, t3) end end symmetricbraiding && @timedtestset "Permutations: test via CPU" begin @@ -292,14 +289,14 @@ for V in spacelist for p in permutations(1:5) p1 = ntuple(n -> p[n], k) p2 = ntuple(n -> p[k + n], 5 - k) - dt2 = CUDA.@allowscalar permute(t, (p1, p2)) + dt2 = permute(t, (p1, p2)) ht2 = permute(TensorKit.to_cpu(t), (p1, p2)) - @test ht2 == TensorKit.to_cpu(dt2) + @test ht2 ≈ TensorKit.to_cpu(dt2) end dt3 = CUDA.@allowscalar repartition(t, k) ht3 = repartition(TensorKit.to_cpu(t), k) - @test ht3 == TensorKit.to_cpu(dt3) + @test ht3 ≈ TensorKit.to_cpu(dt3) end end symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin @@ -347,7 +344,7 @@ for V in spacelist end @test ta ≈ tb end - #=if BraidingStyle(I) isa Bosonic && hasfusiontensor(I) + if BraidingStyle(I) isa Bosonic && hasfusiontensor(I) @timedtestset "Tensor contraction: test via CPU" begin dA1 = CUDA.randn(ComplexF64, V1' * V2', V3') dA2 = CUDA.randn(ComplexF64, V3 * V4, V5) @@ -362,7 +359,7 @@ for V in spacelist TensorKit.to_cpu(dH)[s1, s2, t1, t2] @test TensorKit.to_cpu(dHrA12) ≈ hHrA12 end - end=# # doesn't yet work because of AdjointTensor + end BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin t = CUDA.rand(ComplexF64, V1 ⊗ V2 ⊗ V3 ← (V4 ⊗ V5)') for i in 1:5 diff --git a/test/setup.jl b/test/setup.jl index 37c682d62..cb230a2dd 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -111,20 +111,22 @@ function force_planar(V::GradedSpace) end force_planar(V::ProductSpace) = mapreduce(force_planar, ⊗, V) function force_planar(tsrc::TensorMap{<:Any, ComplexSpace}) - tdst = TensorMap{scalartype(tsrc)}( + tdst = TensorKit.TensorMapWithStorage{scalartype(tsrc), storagetype(tsrc)}( undef, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc)) ) + tdst = similar(tsrc, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc))) copyto!(block(tdst, PlanarTrivial()), block(tsrc, Trivial())) return tdst end function force_planar(tsrc::TensorMap{<:Any, <:GradedSpace}) - tdst = TensorMap{scalartype(tsrc)}( + tdst = TensorKit.TensorMapWithStorage{scalartype(tsrc), storagetype(tsrc)}( undef, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc)) ) + tdst = similar(tsrc, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc))) for (c, b) in blocks(tsrc) copyto!(block(tdst, c ⊠ PlanarTrivial()), b) end From 832fdf356cf0dbf74c8fedb7ccb5f022c5c2a2a6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Apr 2026 09:35:43 +0200 Subject: [PATCH 02/10] Remove extraneous lines Co-authored-by: Jutho --- src/tensors/braidingtensor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index e1c6694ec..992f52f49 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -128,8 +128,6 @@ end throw(SectorMismatch()) end d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...) - n1 = d[1] * d[2] - n2 = d[3] * d[4] data_parent = storagetype(b)(undef, prod(d)) data = sreshape(StridedView(data_parent), d) r = _braiding_factor(f₁, f₂, b.adjoint) From f4c77106fea35ba0d32177d0b3975897fd76eb9d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Apr 2026 03:57:48 -0400 Subject: [PATCH 03/10] More extraneous line removal --- test/setup.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/setup.jl b/test/setup.jl index cb230a2dd..106aa6c36 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -111,21 +111,11 @@ function force_planar(V::GradedSpace) end force_planar(V::ProductSpace) = mapreduce(force_planar, ⊗, V) function force_planar(tsrc::TensorMap{<:Any, ComplexSpace}) - tdst = TensorKit.TensorMapWithStorage{scalartype(tsrc), storagetype(tsrc)}( - undef, - force_planar(codomain(tsrc)) ← - force_planar(domain(tsrc)) - ) tdst = similar(tsrc, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc))) copyto!(block(tdst, PlanarTrivial()), block(tsrc, Trivial())) return tdst end function force_planar(tsrc::TensorMap{<:Any, <:GradedSpace}) - tdst = TensorKit.TensorMapWithStorage{scalartype(tsrc), storagetype(tsrc)}( - undef, - force_planar(codomain(tsrc)) ← - force_planar(domain(tsrc)) - ) tdst = similar(tsrc, force_planar(codomain(tsrc)) ← force_planar(domain(tsrc))) for (c, b) in blocks(tsrc) copyto!(block(tdst, c ⊠ PlanarTrivial()), b) From 1ca35b63e4aabf631d2192f1ec2327a26ee5df17 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Apr 2026 05:06:29 -0400 Subject: [PATCH 04/10] A little more cleanup --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 83e34a0a7..63beed07c 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -5,8 +5,7 @@ using CUDA: @allowscalar using cuTENSOR: cuTENSOR using Strided: StridedViews import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn! -using CUDA.KernelAbstractions -using CUDA.KernelAbstractions: @kernel, @index +using CUDA.KernelAbstractions: @kernel, @index, get_backend using TensorKit using TensorKit.Factorizations @@ -28,7 +27,7 @@ function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{< idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val) @inbounds subblock[idx] = idx_val end - kernel = fill_subblock_kernel!(KernelAbstractions.get_backend(data)) + kernel = fill_subblock_kernel!(get_backend(data)) kernel(data, val; ndrange = size(data)) return data end From 925c9c5892f12369e7a2ed0db5483672aa21695e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Apr 2026 07:04:22 -0400 Subject: [PATCH 05/10] Coverage --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 3 +++ test/cuda/planar.jl | 2 ++ 2 files changed, 5 insertions(+) diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 63beed07c..530c8cc85 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -22,11 +22,14 @@ include("cutensormap.jl") include("truncation.jl") function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}} + # COV_EXCL_START + # kernels are not reachable by coverage @kernel function fill_subblock_kernel!(subblock, val) idx = @index(Global, Cartesian) idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val) @inbounds subblock[idx] = idx_val end + # COV_EXCL_STOP kernel = fill_subblock_kernel!(get_backend(data)) kernel(data, val; ndrange = size(data)) return data diff --git a/test/cuda/planar.jl b/test/cuda/planar.jl index 454515c5d..35d4feae5 100644 --- a/test/cuda/planar.jl +++ b/test/cuda/planar.jl @@ -30,6 +30,8 @@ for V in spacelist @test storagetype(t2) == CuVector{ComplexF64, CUDA.DeviceMemory} t3 = @testinferred adapt(storagetype(t2), t1) @test storagetype(t3) == storagetype(t2) + t4 = @testinferred adapt(scalartype(t2), t1) + @test storagetype(t3) == storagetype(t2) # allowscalar needed for the StridedView comparison @test t3 ≈ t2 From 8e515b5062a2d00ee7d6af3c04fdc5ca167931fc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 23 Apr 2026 07:04:57 +0200 Subject: [PATCH 06/10] Update src/tensors/braidingtensor.jl Co-authored-by: Lukas Devos --- src/tensors/braidingtensor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 992f52f49..6c74f142c 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -166,8 +166,6 @@ function fill_braidingblock!(data, b::BraidingTensor, s::Sector) r = _braiding_factor(f₁, f₂, b.adjoint) # change offset to account for single block subblock = StridedView(data, sz, str, off - base_offset) - # without the zero-value, the non-trivial block is not set - # correctly in the GPU case isnothing(r) ? zerovector!(subblock) : fill_braidingsubblock!(subblock, r) end return data From 9a07b426d95dff49f340e9aa4955f874963ee60e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 23 Apr 2026 02:24:27 -0400 Subject: [PATCH 07/10] Use braidingtensortype --- src/planar/preprocessors.jl | 3 ++- src/tensors/braidingtensor.jl | 25 +++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 3098fc4cf..e6b67eb64 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -184,7 +184,8 @@ function _construct_braidingtensors!(ex, preargs, indexmap, non_braiding) # ex i if foundV1 && foundV2 s = gensym(:τ) storageex = Expr(:call, GlobalRef(TensorKit, :promote_storagetype), non_braiding...) - constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), storageex, V1, V2) + braidingex = Expr(:call, GlobalRef(TensorKit, :braidingtensortype), V1, V2, storageex) + constructex = Expr(:call, braidingex, V1, V2) push!(preargs, Expr(:(=), s, constructex)) obj = _is_adjoint(obj) ? _add_adjoint(s) : s success = true diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 6c74f142c..2e7705ddb 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -53,25 +53,30 @@ function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T} throw(SpaceMismatch("Cannot define a braiding on $V")) return BraidingTensor{T}(V[2], V[1], adjoint) end + +function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} + return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint) +end + # these are here to make the preprocessing for `@planar` expressions less painful -function BraidingTensor(TorA::Type, V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} +function braidingtensortype(::Type{S}, ::Type{TorA}) where {S <: IndexSpace, TorA} T = eltype(TorA) TA = similarstoragetype(TorA) - return BraidingTensor{T, S, TA}(V1, V2, adjoint) + return BraidingTensor{T, S, TA} end -function BraidingTensor(TorA::Type, V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) - return BraidingTensor(TorA, promote(V1, V2)..., adjoint) +braidingtensortype(V::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA) +braidingtensortype(V1::S, V2::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA) +function braidingtensortype(V1::IndexSpace, V2::IndexSpace, ::Type{TorA}) where {TorA} + S = promote(V1, V2) + return braidingtensortype(S..., TorA) end -function BraidingTensor(TorA::Type, V::HomSpace, adjoint::Bool = false) +function braidingtensortype(V::HomSpace, ::Type{TorA}) where {TorA} domain(V) == reverse(codomain(V)) || throw(SpaceMismatch("Cannot define a braiding on $V")) T = eltype(TorA) S = spacetype(V[2]) - TA = storagetype(TorA) - return BraidingTensor{T, S, TA}(V[2], V[1], adjoint) -end -function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A} - return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint) + TA = similarstoragetype(TorA) + return braidingtensortype(S, TA) end storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A From bd9d394d4a4e65cc9a3adf2c0a321f94d09c3b61 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 23 Apr 2026 15:32:40 +0200 Subject: [PATCH 08/10] Apply suggestions from code review Co-authored-by: Lukas Devos --- src/tensors/braidingtensor.jl | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 2e7705ddb..405aa1407 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -28,8 +28,7 @@ struct BraidingTensor{T, S, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2 end end function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace} - TA = similarstoragetype(T) - return BraidingTensor{T, S, TA}(V1, V2, adjoint) + return braidingtensortype(S, T)(V1, V2, adjoint) end function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace} T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64 @@ -60,9 +59,8 @@ end # these are here to make the preprocessing for `@planar` expressions less painful function braidingtensortype(::Type{S}, ::Type{TorA}) where {S <: IndexSpace, TorA} - T = eltype(TorA) - TA = similarstoragetype(TorA) - return BraidingTensor{T, S, TA} + A = similarstoragetype(TorA) + return BraidingTensor{scalartype(A), S, A} end braidingtensortype(V::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA) braidingtensortype(V1::S, V2::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA) @@ -73,10 +71,7 @@ end function braidingtensortype(V::HomSpace, ::Type{TorA}) where {TorA} domain(V) == reverse(codomain(V)) || throw(SpaceMismatch("Cannot define a braiding on $V")) - T = eltype(TorA) - S = spacetype(V[2]) - TA = similarstoragetype(TorA) - return braidingtensortype(S, TA) + return braidingtensortype(spacetype(V), TorA) end storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A From a69385d70ff88ea69e2bb275c6051513d36d2dfc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 24 Apr 2026 13:55:30 +0200 Subject: [PATCH 09/10] Update src/tensors/braidingtensor.jl Co-authored-by: Jutho --- src/tensors/braidingtensor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 405aa1407..f08dd8181 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -69,8 +69,6 @@ function braidingtensortype(V1::IndexSpace, V2::IndexSpace, ::Type{TorA}) where return braidingtensortype(S..., TorA) end function braidingtensortype(V::HomSpace, ::Type{TorA}) where {TorA} - domain(V) == reverse(codomain(V)) || - throw(SpaceMismatch("Cannot define a braiding on $V")) return braidingtensortype(spacetype(V), TorA) end From 3f1fdd80a19139c9dfba858b04857c35473fd495 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 24 Apr 2026 16:25:57 +0200 Subject: [PATCH 10/10] Update Project.toml bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9df43e067..0308e3933 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -version = "0.16.4" +version = "0.17.0" authors = ["Jutho Haegeman, Lukas Devos"] [deps]