Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorKit"
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
version = "0.16.4"
version = "0.17.0"
authors = ["Jutho Haegeman, Lukas Devos"]

[deps]
Expand Down
8 changes: 6 additions & 2 deletions ext/TensorKitAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
data′ = adapt(to, x.data)
return DiagonalTensorMap(data′, x.domain)
end
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
A′ = TensorKit.similarstoragetype(A, T)
return BraidingTensor{T, S, A′}(space(x), x.adjoint)
end
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A}
return BraidingTensor{T′, S, TA}(space(x), x.adjoint)
end

end
18 changes: 17 additions & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ module TensorKitCUDAExt
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
using CUDA: @allowscalar
using cuTENSOR: cuTENSOR
using Strided: StridedViews
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
using CUDA.KernelAbstractions: @kernel, @index, get_backend

using TensorKit
using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, fill_braidingsubblock!

using TensorKit: MatrixAlgebraKit

Expand All @@ -19,4 +21,18 @@ using Random
include("cutensormap.jl")
include("truncation.jl")

function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
# COV_EXCL_START
# kernels are not reachable by coverage
@kernel function fill_subblock_kernel!(subblock, val)
idx = @index(Global, Cartesian)
idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val)
@inbounds subblock[idx] = idx_val
end
# COV_EXCL_STOP
kernel = fill_subblock_kernel!(get_backend(data))
kernel(data, val; ndrange = size(data))
return data
end

end
4 changes: 4 additions & 0 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
return tf
end
end

function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS}
return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...)
end
Comment thread
kshyatt marked this conversation as resolved.
40 changes: 31 additions & 9 deletions src/planar/preprocessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ _add_adjoint(ex) = Expr(TO.prime, ex)
# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
# insert them in the expression.
function _construct_braidingtensors(ex)
function filter_f(expr)
if TO.istensor(expr)
return _remove_adjoint(TO.decomposetensor(expr)[1]) != :τ
elseif TO.istensorexpr(expr)
return any(filter_f, expr.args)
else
return false
end
end
function extract_tensors(tensor_ex)
if TO.istensor(tensor_ex)
return [TO.decomposetensor(tensor_ex)[1]]
elseif TO.istensorexpr(tensor_ex)
return collect(Iterators.flatmap(extract_tensors, filter(filter_f, tensor_ex.args)))
end
end
# get storagetype
ex isa Expr || return ex
if ex.head == :macrocall && ex.args[1] == Symbol("@notensor")
return ex
Expand All @@ -104,7 +121,9 @@ function _construct_braidingtensors(ex)
)
end
end
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap)
# if this is a definition, the lhs tensor is NOT yet defined
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, rhs.args)); init = Symbol[])
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap, no_τ_ex)
success ||
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
pre = Expr(
Expand All @@ -115,7 +134,8 @@ function _construct_braidingtensors(ex)
elseif TO.istensorexpr(ex)
preargs = Vector{Any}()
indexmap = Dict{Any, Any}()
newex, success = _construct_braidingtensors!(ex, preargs, indexmap)
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, ex.args)); init = Symbol[])
newex, success = _construct_braidingtensors!(ex, preargs, indexmap, no_τ_ex)
success ||
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
pre = Expr(
Expand All @@ -128,7 +148,7 @@ function _construct_braidingtensors(ex)
end
end

function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression
function _construct_braidingtensors!(ex, preargs, indexmap, non_braiding) # ex is guaranteed to be a single tensor expression
if TO.isscalarexpr(ex)
# ex could be tensorscalar call with more braiding tensors
return _construct_braidingtensors(ex), true
Expand Down Expand Up @@ -163,7 +183,9 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
end
if foundV1 && foundV2
s = gensym(:τ)
constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2)
storageex = Expr(:call, GlobalRef(TensorKit, :promote_storagetype), non_braiding...)
braidingex = Expr(:call, GlobalRef(TensorKit, :braidingtensortype), V1, V2, storageex)
constructex = Expr(:call, braidingex, V1, V2)
push!(preargs, Expr(:(=), s, constructex))
obj = _is_adjoint(obj) ? _add_adjoint(s) : s
success = true
Expand Down Expand Up @@ -196,7 +218,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
newargs = Vector{Any}(undef, length(args))
success = true
for i in 1:length(ex.args)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap, non_braiding)
success = success && successa
end
newex = Expr(ex.head, newargs...)
Expand All @@ -212,7 +234,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
for i in 2:length(ex.args)
successes[i] && continue
newargs[i], successa = _construct_braidingtensors!(
args[i], preargs, indexmap
args[i], preargs, indexmap, non_braiding
)
successes[i] = successa
end
Expand All @@ -232,7 +254,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
indices = [TO.getindices(arg) for arg in args]
for i in 2:length(ex.args)
indexmapa = copy(indexmap)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa, non_braiding)
for l in indices[i]
if !haskey(indexmap, l) && haskey(indexmapa, l)
indexmap[l] = indexmapa[l]
Expand All @@ -243,10 +265,10 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
newex = Expr(ex.head, newargs...)
return newex, success
elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap)
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap, non_braiding)
return Expr(:call, :/, newarg, ex.args[3]), success
elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap)
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap, non_braiding)
return Expr(:call, :\, ex.args[2], newarg), success
else
error("unexpected expression $ex")
Expand Down
149 changes: 78 additions & 71 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,78 @@
# special (2,2) tensor that implements a standard braiding operation
#====================================================================#
"""
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
struct BraidingTensor{T, S <: IndexSpace, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}

Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
braids the first input over the second input; its inverse can be obtained as the adjoint.

It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
controls the array type of the braiding tensor used when indexing
and multiplying with other tensors.
"""
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
struct BraidingTensor{T, S, A <: DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
V1::S
V2::S
adjoint::Bool
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
for a in sectors(V1)
for b in sectors(V2)
for c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
end
function BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
for a in sectors(V1), b in sectors(V2), c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
return new{T, S}(V1, V2, adjoint)
return new{T, S, A}(V1, V2, adjoint)
# partial construction: only construct rowr and colr when needed
end
end
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
return BraidingTensor{T, S}(V1, V2, adjoint)
return braidingtensortype(S, T)(V1, V2, adjoint)
end
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., adjoint)
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
return BraidingTensor{T}(V1, V2, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
return BraidingTensor(promote(V1, V2)..., adjoint)
end
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
return BraidingTensor{T, S}(V1, V2, adjoint)
end
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[2], V[1], adjoint)
end
function BraidingTensor{T, S, A}(V::HomSpace, adjoint::Bool = false) where {T, S, A}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor{T, S, A}(V[2], V[1], adjoint)
end
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor{T}(V[2], V[1], adjoint)
end
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
end

space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint)
end

# specializations to ignore the storagetype of BraidingTensor
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)
# these are here to make the preprocessing for `@planar` expressions less painful
Comment thread
kshyatt marked this conversation as resolved.
function braidingtensortype(::Type{S}, ::Type{TorA}) where {S <: IndexSpace, TorA}
A = similarstoragetype(TorA)
return BraidingTensor{scalartype(A), S, A}
end
braidingtensortype(V::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA)
braidingtensortype(V1::S, V2::S, ::Type{TorA}) where {S <: IndexSpace, TorA} = braidingtensortype(S, TorA)
function braidingtensortype(V1::IndexSpace, V2::IndexSpace, ::Type{TorA}) where {TorA}
S = promote(V1, V2)
return braidingtensortype(S..., TorA)
end
function braidingtensortype(V::HomSpace, ::Type{TorA}) where {TorA}
return braidingtensortype(spacetype(V), TorA)
end

promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
similarstoragetype(B, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
similarstoragetype(A, T)
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
similarstoragetype(A, T)
storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2

function Base.getindex(b::BraidingTensor)
sectortype(b) === Trivial || throw(SectorMismatch())
Expand Down Expand Up @@ -99,6 +105,13 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
return r
end

# generates scalar indexing errors on GPU
function fill_braidingsubblock!(data, val)
f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val
return data .= f.(CartesianIndices(data))
end


@inline function subblock(
b::BraidingTensor, (f₁, f₂)::Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
) where {I <: Sector}
Expand All @@ -113,17 +126,10 @@ end
throw(SectorMismatch())
end
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
fill!(data, zero(eltype(b)))

data_parent = storagetype(b)(undef, prod(d))
data = sreshape(StridedView(data_parent), d)
r = _braiding_factor(f₁, f₂, b.adjoint)
if !isnothing(r)
@inbounds for i in axes(data, 1), j in axes(data, 2)
data[i, j, j, i] = r
end
end
isnothing(r) ? zerovector!(data) : fill_braidingsubblock!(data, r)
return data
end

Expand All @@ -134,49 +140,50 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

Base.complex(b::BraidingTensor{<:Complex}) = b
function Base.complex(b::BraidingTensor)
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
Tc = complex(T)
Ac = similarstoragetype(A, Tc)
return BraidingTensor{Tc, S, Ac}(space(b), b.adjoint)
end

function block(b::BraidingTensor, s::Sector)
I = sectortype(b)
I == typeof(s) || throw(SectorMismatch())

# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)
data = Matrix{eltype(b)}(undef, (m, n))
Comment thread
kshyatt marked this conversation as resolved.

length(data) == 0 && return data # s ∉ blocksectors(b)

data = fill!(data, zero(eltype(b)))

# Trivial
function fill_braidingblock!(data, b::BraidingTensor, s::Trivial)
V1, V2 = codomain(b)
if sectortype(b) === Trivial
d1, d2 = dim(V1), dim(V2)
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
subblock[i, j, j, i] = one(eltype(b))
end
return data
end
d1, d2 = dim(V1), dim(V2)
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
fill_braidingsubblock!(subblock, one(eltype(b)))
return data
end

# Nontrivial
function fill_braidingblock!(data, b::BraidingTensor, s::Sector)
base_offset = first(blockstructure(b)[s][2]) - 1

for ((f₁, f₂), (sz, str, off)) in pairs(subblockstructure(space(b)))
(f₁.coupled == f₂.coupled == s) || continue
r = _braiding_factor(f₁, f₂, b.adjoint)
isnothing(r) && continue
# change offset to account for single block
subblock = StridedView(data, sz, str, off - base_offset)
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
subblock[i, j, j, i] = r
end
isnothing(r) ? zerovector!(subblock) : fill_braidingsubblock!(subblock, r)
end

return data
end

function block(b::BraidingTensor, s::Sector)
I = sectortype(b)
I == typeof(s) || throw(SectorMismatch())

# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)

data = reshape(storagetype(b)(undef, m * n), (m, n))

m * n == 0 && return data # s ∉ blocksectors(b)

return fill_braidingblock!(data, b, s)
end

# Index manipulations
# -------------------
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
Expand Down
Loading
Loading