diff --git a/Project.toml b/Project.toml index aaf0a86..fe17043 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.13.2" +version = "0.13.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl b/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl index 8913e1c..32efb8f 100644 --- a/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl +++ b/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl @@ -13,6 +13,15 @@ Mooncake.tangent_type(::Type{<:ContractAlgorithm}) = Mooncake.NoTangent } @zero_derivative DefaultCtx Tuple{typeof(biperm), Any, Any, Any} @zero_derivative DefaultCtx Tuple{typeof(biperms), typeof(contract), Any, Any, Any} +@zero_derivative DefaultCtx Tuple{ + typeof(biperms), + typeof(contract), + Val, + Any, + Any, + Any, + Any, +} @zero_derivative DefaultCtx Tuple{ typeof(check_input), typeof(contract), Any, Any, Any, Any, Any, Any, } diff --git a/src/contract/biperms.jl b/src/contract/biperms.jl index d287f65..0fba70a 100644 --- a/src/contract/biperms.jl +++ b/src/contract/biperms.jl @@ -1,8 +1,10 @@ -# `a ∖ b` and `a ∩ b` as a `Vector`, preserving the order of `a`, via a linear -# scan. For the small collections here `Base.setdiff`/`intersect` are slower -# because they build a `Set` and hash. Both assume set-like (unique) inputs. +import TupleTools + +# `a ∖ b` as a `Vector`, preserving the order of `a`, via a linear scan. For the small +# collections here `Base.setdiff` is slower because it builds a `Set` and hashes; it +# assumes set-like (unique) inputs. Used to assemble the destination labels in +# `contract_labels`. smallsetdiff(a, b) = [x for x in a if x ∉ b] -smallintersect(a, b) = [x for x in a if x ∈ b] # Position of each element of `x` in `y`, as a tuple. Linear scan, no hashing # (`Base.indexin` builds a `Dict`), for the small collections here. @@ -32,24 +34,41 @@ length_domain(t) = 0 length_codomain(t) = length(t) - length_domain(t) +# `findfirst` for a match the caller guarantees exists, so the result is an `Int` rather +# than `Union{Int, Nothing}` (the `Nothing` would otherwise break inference downstream). +checked_findfirst(pred, collection) = something(findfirst(pred, collection)) + # codomain <-- domain -function biperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) - codomain = Tuple(smallsetdiff(dimnames1, dimnames2)) - contracted = Tuple(smallintersect(dimnames1, dimnames2)) - domain = Tuple(smallsetdiff(dimnames2, dimnames1)) - - # `codomain`/`contracted` and `contracted`/`domain` partition the operands by - # construction, so the only label consistency left to check is that the - # destination carries exactly the uncontracted labels. `biperm` below then - # checks each group lands in the destination. - length(codomain) + length(domain) == length(dimnames_dest) || +function biperms(::typeof(contract), labels_dest, labels1, labels2) + t1, t2 = Tuple(labels1), Tuple(labels2) + contracted1 = map(in(t2), t1) + return biperms(contract, Val(count(contracted1)), labels_dest, t1, t2, contracted1) +end +# `K` is the number of contracted labels. Passing it as a `Val` makes the group sizes +# compile-time constants, so the permutations below come out as concretely-typed tuples and +# the rest of the contraction stays type-stable. `contracted1` is the boolean mask of which +# of `labels1`'s labels are contracted (its `count` is `K`), threaded in from the caller. +function biperms( + ::typeof(contract), ::Val{K}, labels_dest, labels1, labels2, contracted1 + ) where {K} + n1, n2 = length(labels1), length(labels2) + # `sortperm` of the boolean mask is a stable partition: uncontracted (`false`) indices + # first, contracted (`true`) indices last, each in their original order. + perm1_codomain, perm1_domain = + bipartition(TupleTools.sortperm(contracted1), Val(n1 - K)) + perm2_domain, _ = + bipartition(TupleTools.sortperm(map(in(labels1), labels2)), Val(n2 - K)) + # Align the contracted groups: list operand 2's contracted labels in operand 1's order. + perm2_codomain = map(p -> checked_findfirst(==(labels1[p]), labels2), perm1_domain) + # The operands partition into (un)contracted groups by construction; the only label + # consistency left to check is that the destination carries exactly the uncontracted + # labels. Locating each below then checks they all land in the destination. + length(labels_dest) == (n1 - K) + (n2 - K) || throw(ArgumentError("Invalid contraction labels")) - - perm_codomain_dest, perm_domain_dest = biperm(dimnames_dest, codomain, domain) - invperm_dest = invperm((perm_codomain_dest..., perm_domain_dest...)) - biperm_dest = bipartition(invperm_dest, Val(length(codomain))) - - biperm1 = biperm(dimnames1, codomain, contracted) - biperm2 = biperm(dimnames2, contracted, domain) - return biperm_dest, biperm1, biperm2 + pos_dest = ( + map(p -> checked_findfirst(==(labels1[p]), labels_dest), perm1_codomain)..., + map(p -> checked_findfirst(==(labels2[p]), labels_dest), perm2_domain)..., + ) + biperm_dest = bipartition(invperm(pos_dest), Val(n1 - K)) + return biperm_dest, (perm1_codomain, perm1_domain), (perm2_codomain, perm2_domain) end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 7141b6d..a8a6d19 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -9,15 +9,30 @@ end function contract( labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... ) - (perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain), - (perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2) - return contract( - perm_dest_codomain, perm_dest_domain, - a1, perm1_codomain, perm1_domain, - a2, perm2_codomain, perm2_domain; + t1 = ntuple(i -> labels1[i], Val(ndims(a1))) + t2 = ntuple(i -> labels2[i], Val(ndims(a2))) + contracted1 = map(in(t2), t1) + # Cross into a `Val(K)` method (a function-barrier on the contracted count) so the + # bipartitioned permutations and the contraction below them are type-stable. + return _contract( + Val(count(contracted1)), + labels_dest, + a1, + t1, + a2, + t2, + contracted1; kwargs... ) end +function _contract( + ::Val{K}, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, + contracted1; kwargs... + ) where {K} + biperm_dest, biperm1, biperm2 = + biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1) + return contract(biperm_dest..., a1, biperm1..., a2, biperm2...; kwargs...) +end # contract (bipartitioned permutations) function contract( @@ -117,13 +132,25 @@ function contractopadd!( α::Number, β::Number; kwargs... ) - (perm_dest_codomain, perm_dest_domain), (perm1_codomain, perm1_domain), - (perm2_codomain, perm2_domain) = biperms(contract, labels_dest, labels1, labels2) + t1 = ntuple(i -> labels1[i], Val(ndims(a1))) + t2 = ntuple(i -> labels2[i], Val(ndims(a2))) + contracted1 = map(in(t2), t1) + # Cross into a `Val(K)` method (a function-barrier on the contracted count) so the + # bipartitioned permutations and the contraction below them are type-stable. + return _contractopadd!( + Val(count(contracted1)), a_dest, labels_dest, + op1, a1, t1, op2, a2, t2, α, β, contracted1; kwargs... + ) +end +function _contractopadd!( + ::Val{K}, a_dest::AbstractArray, labels_dest, + op1, a1::AbstractArray, labels1, op2, a2::AbstractArray, labels2, + α::Number, β::Number, contracted1; kwargs... + ) where {K} + biperm_dest, biperm1, biperm2 = + biperms(contract, Val(K), labels_dest, labels1, labels2, contracted1) return contractopadd!( - a_dest, perm_dest_codomain, perm_dest_domain, - op1, a1, perm1_codomain, perm1_domain, - op2, a2, perm2_codomain, perm2_domain, - α, β; kwargs... + a_dest, biperm_dest..., op1, a1, biperm1..., op2, a2, biperm2..., α, β; kwargs... ) end # contractopadd! (bipartitioned permutations, algorithm selection) diff --git a/test/test_setoperations.jl b/test/test_setoperations.jl index 7a7bd8a..c489854 100644 --- a/test/test_setoperations.jl +++ b/test/test_setoperations.jl @@ -1,15 +1,13 @@ -using TensorAlgebra: biperms, contract, smallintersect, smallsetdiff, tuple_indexin +using TensorAlgebra: biperms, contract, smallsetdiff, tuple_indexin using Test: @test, @test_throws, @testset -@testset "smallsetdiff/smallintersect" begin +@testset "smallsetdiff" begin # Order-preserving, returning a `Vector`. @test smallsetdiff((:i, :j, :k), (:k, :i)) == [:j] - @test smallintersect((:i, :j, :k), (:k, :i)) == [:i, :k] @test smallsetdiff([:i, :j, :k], [:k, :i]) == [:j] - @test smallintersect([:i, :j, :k], [:k, :i]) == [:i, :k] # Disjoint and empty cases. @test smallsetdiff((:i, :j), ()) == [:i, :j] - @test smallintersect((:i, :j), (:k,)) == [] + @test smallsetdiff((:i, :j), (:i, :j)) == [] end @testset "tuple_indexin" begin