diff --git a/Project.toml b/Project.toml index b34ba28..aaf8772 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" -version = "0.9.0" +version = "0.9.1" authors = ["ITensor developers and contributors"] [workspace] diff --git a/ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl b/ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl index 8b52725..164784d 100644 --- a/ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl +++ b/ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl @@ -1,12 +1,14 @@ module ITensorBaseMooncakeExt -using ITensorBase: AbstractNamedTensor, NamedUnitRange, dimnames, dimnames_setdiff, inds, - name, nameperm, to_inds, uniquename +using ITensorBase: AbstractNamedTensor, NamedUnitRange, dimnames, dimnames_setdiff, + from_contract_labels, inds, name, nameperm, to_contract_labels, to_inds, uniquename using Mooncake: Mooncake, @zero_derivative, DefaultCtx Mooncake.tangent_type(::Type{<:NamedUnitRange}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(nameperm), Any, Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(to_contract_labels), Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(from_contract_labels), Any, Any, Any} # `dimnames(::NamedTensor)` returns the stored names `Vector` directly, so its output # aliases a field, where `@zero_derivative` is documented to be incorrect. Let # Mooncake differentiate it through the underlying `getfield`, whose built-in rule diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 27bcd88..ffca850 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -7,11 +7,41 @@ using TupleTools: TupleTools # This layer is used to define derivative rules (to skip differentiating `setdiff`). dimnames_setdiff(s1, s2) = setdiff(s1, s2) +# Relabel the global dimension names to labels local to this contraction: operand 1's names +# become `1:length(dimnames1)`, and each of operand 2's names reuses operand 1's label where +# they match (the contracted dimensions) and otherwise gets a fresh label +# `length(dimnames1) + position`. Encoding the fresh labels by position lets +# `from_contract_labels` recover the uncontracted operand-2 names by position. +function to_contract_labels(dimnames1, dimnames2) + n1len = length(dimnames1) + labels2 = map(eachindex(dimnames2)) do i2 + i1 = findfirst(==(dimnames2[i2]), dimnames1) + return isnothing(i1) ? n1len + i2 : i1 + end + return 1:n1len, labels2 +end + +# Invert `to_contract_labels`: map the result labels back to the global dimension names of the +# two operands by position. +function from_contract_labels(labels, dimnames1, dimnames2) + n1len = length(dimnames1) + return map( + label -> label <= n1len ? dimnames1[label] : dimnames2[label - n1len], + labels + ) +end + Base.:*(a1::AbstractNamedTensor, a2::AbstractNamedTensor) = mul_nameddims(a1, a2) function mul_nameddims(a1::AbstractNamedTensor, a2::AbstractNamedTensor) - a_dest, dimnames_dest = TA.contract( - unnamed(a1), dimnames(a1), unnamed(a2), dimnames(a2) - ) + dimnames1, dimnames2 = dimnames(a1), dimnames(a2) + # The contraction structure depends only on the equality pattern of the dimension names, + # so relabel them to integers local to this contraction once and run the contraction-label + # bookkeeping on cheap integers, recovering the result names afterward. This keeps + # `TensorAlgebra`'s `setdiff`/`findfirst` passes off the dimension-name type, which for + # `IndexName` carries an id and a tag dictionary and is comparatively expensive to compare. + labels1, labels2 = to_contract_labels(dimnames1, dimnames2) + a_dest, labels_dest = TA.contract(unnamed(a1), labels1, unnamed(a2), labels2) + dimnames_dest = from_contract_labels(labels_dest, dimnames1, dimnames2) return nameddims(a_dest, dimnames_dest) end