Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.13.2"
version = "0.13.3"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
9 changes: 9 additions & 0 deletions ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ Mooncake.tangent_type(::Type{<:ContractAlgorithm}) = Mooncake.NoTangent
}
@zero_derivative DefaultCtx Tuple{typeof(biperm), Any, Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(biperms), typeof(contract), Any, Any, Any}
@zero_derivative DefaultCtx Tuple{
typeof(biperms),
typeof(contract),
Val,
Any,
Any,
Any,
Any,
}
@zero_derivative DefaultCtx Tuple{
typeof(check_input), typeof(contract), Any, Any, Any, Any, Any, Any,
}
Expand Down
63 changes: 41 additions & 22 deletions src/contract/biperms.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# `a ∖ b` and `a ∩ b` as a `Vector`, preserving the order of `a`, via a linear
# scan. For the small collections here `Base.setdiff`/`intersect` are slower
# because they build a `Set` and hash. Both assume set-like (unique) inputs.
import TupleTools

# `a ∖ b` as a `Vector`, preserving the order of `a`, via a linear scan. For the small
# collections here `Base.setdiff` is slower because it builds a `Set` and hashes; it
# assumes set-like (unique) inputs. Used to assemble the destination labels in
# `contract_labels`.
smallsetdiff(a, b) = [x for x in a if x ∉ b]
smallintersect(a, b) = [x for x in a if x ∈ b]

# Position of each element of `x` in `y`, as a tuple. Linear scan, no hashing
# (`Base.indexin` builds a `Dict`), for the small collections here.
Expand Down Expand Up @@ -32,24 +34,41 @@ length_domain(t) = 0

length_codomain(t) = length(t) - length_domain(t)

# `findfirst` for a match the caller guarantees exists, so the result is an `Int` rather
# than `Union{Int, Nothing}` (the `Nothing` would otherwise break inference downstream).
checked_findfirst(pred, collection) = something(findfirst(pred, collection))

# codomain <-- domain
function biperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
codomain = Tuple(smallsetdiff(dimnames1, dimnames2))
contracted = Tuple(smallintersect(dimnames1, dimnames2))
domain = Tuple(smallsetdiff(dimnames2, dimnames1))

# `codomain`/`contracted` and `contracted`/`domain` partition the operands by
# construction, so the only label consistency left to check is that the
# destination carries exactly the uncontracted labels. `biperm` below then
# checks each group lands in the destination.
length(codomain) + length(domain) == length(dimnames_dest) ||
function biperms(::typeof(contract), labels_dest, labels1, labels2)
t1, t2 = Tuple(labels1), Tuple(labels2)
contracted1 = map(in(t2), t1)
return biperms(contract, Val(count(contracted1)), labels_dest, t1, t2, contracted1)
end
# `K` is the number of contracted labels. Passing it as a `Val` makes the group sizes
# compile-time constants, so the permutations below come out as concretely-typed tuples and
# the rest of the contraction stays type-stable. `contracted1` is the boolean mask of which
# of `labels1`'s labels are contracted (its `count` is `K`), threaded in from the caller.
function biperms(
::typeof(contract), ::Val{K}, labels_dest, labels1, labels2, contracted1
) where {K}
n1, n2 = length(labels1), length(labels2)
# `sortperm` of the boolean mask is a stable partition: uncontracted (`false`) indices
# first, contracted (`true`) indices last, each in their original order.
perm1_codomain, perm1_domain =
bipartition(TupleTools.sortperm(contracted1), Val(n1 - K))
perm2_domain, _ =
bipartition(TupleTools.sortperm(map(in(labels1), labels2)), Val(n2 - K))
# Align the contracted groups: list operand 2's contracted labels in operand 1's order.
perm2_codomain = map(p -> checked_findfirst(==(labels1[p]), labels2), perm1_domain)
# The operands partition into (un)contracted groups by construction; the only label
# consistency left to check is that the destination carries exactly the uncontracted
# labels. Locating each below then checks they all land in the destination.
length(labels_dest) == (n1 - K) + (n2 - K) ||
throw(ArgumentError("Invalid contraction labels"))

perm_codomain_dest, perm_domain_dest = biperm(dimnames_dest, codomain, domain)
invperm_dest = invperm((perm_codomain_dest..., perm_domain_dest...))
biperm_dest = bipartition(invperm_dest, Val(length(codomain)))

biperm1 = biperm(dimnames1, codomain, contracted)
biperm2 = biperm(dimnames2, contracted, domain)
return biperm_dest, biperm1, biperm2
pos_dest = (
map(p -> checked_findfirst(==(labels1[p]), labels_dest), perm1_codomain)...,
map(p -> checked_findfirst(==(labels2[p]), labels_dest), perm2_domain)...,
)
biperm_dest = bipartition(invperm(pos_dest), Val(n1 - K))
return biperm_dest, (perm1_codomain, perm1_domain), (perm2_codomain, perm2_domain)
end
51 changes: 39 additions & 12 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,30 @@ end
function contract(
labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
)
(perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain),
(perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2)
return contract(
perm_dest_codomain, perm_dest_domain,
a1, perm1_codomain, perm1_domain,
a2, perm2_codomain, perm2_domain;
t1 = ntuple(i -> labels1[i], Val(ndims(a1)))
t2 = ntuple(i -> labels2[i], Val(ndims(a2)))
contracted1 = map(in(t2), t1)
# Cross into a `Val(K)` method (a function-barrier on the contracted count) so the
# bipartitioned permutations and the contraction below them are type-stable.
return _contract(
Val(count(contracted1)),
labels_dest,
a1,
t1,
a2,
t2,
contracted1;
kwargs...
)
end
function _contract(
::Val{K}, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2,
contracted1; kwargs...
) where {K}
biperm_dest, biperm1, biperm2 =
biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1)
return contract(biperm_dest..., a1, biperm1..., a2, biperm2...; kwargs...)
end

# contract (bipartitioned permutations)
function contract(
Expand Down Expand Up @@ -117,13 +132,25 @@ function contractopadd!(
α::Number, β::Number;
kwargs...
)
(perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain),
(perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2)
t1 = ntuple(i -> labels1[i], Val(ndims(a1)))
t2 = ntuple(i -> labels2[i], Val(ndims(a2)))
contracted1 = map(in(t2), t1)
# Cross into a `Val(K)` method (a function-barrier on the contracted count) so the
# bipartitioned permutations and the contraction below them are type-stable.
return _contractopadd!(
Val(count(contracted1)), a_dest, labels_dest,
op1, a1, t1, op2, a2, t2, α, β, contracted1; kwargs...
)
end
function _contractopadd!(
::Val{K}, a_dest::AbstractArray, labels_dest,
op1, a1::AbstractArray, labels1, op2, a2::AbstractArray, labels2,
α::Number, β::Number, contracted1; kwargs...
) where {K}
biperm_dest, biperm1, biperm2 =
biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1)
return contractopadd!(
a_dest, perm_dest_codomain, perm_dest_domain,
op1, a1, perm1_codomain, perm1_domain,
op2, a2, perm2_codomain, perm2_domain,
α, β; kwargs...
a_dest, biperm_dest..., op1, a1, biperm1..., op2, a2, biperm2..., α, β; kwargs...
)
end
# contractopadd! (bipartitioned permutations, algorithm selection)
Expand Down
8 changes: 3 additions & 5 deletions test/test_setoperations.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
using TensorAlgebra: biperms, contract, smallintersect, smallsetdiff, tuple_indexin
using TensorAlgebra: biperms, contract, smallsetdiff, tuple_indexin
using Test: @test, @test_throws, @testset

@testset "smallsetdiff/smallintersect" begin
@testset "smallsetdiff" begin
# Order-preserving, returning a `Vector`.
@test smallsetdiff((:i, :j, :k), (:k, :i)) == [:j]
@test smallintersect((:i, :j, :k), (:k, :i)) == [:i, :k]
@test smallsetdiff([:i, :j, :k], [:k, :i]) == [:j]
@test smallintersect([:i, :j, :k], [:k, :i]) == [:i, :k]
# Disjoint and empty cases.
@test smallsetdiff((:i, :j), ()) == [:i, :j]
@test smallintersect((:i, :j), (:k,)) == []
@test smallsetdiff((:i, :j), (:i, :j)) == []
end

@testset "tuple_indexin" begin
Expand Down
Loading