Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorBase"
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
version = "0.9.0"
version = "0.9.1"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
6 changes: 4 additions & 2 deletions ext/ITensorBaseMooncakeExt/ITensorBaseMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module ITensorBaseMooncakeExt

using ITensorBase: AbstractNamedTensor, NamedUnitRange, dimnames, dimnames_setdiff, inds,
name, nameperm, to_inds, uniquename
using ITensorBase: AbstractNamedTensor, NamedUnitRange, dimnames, dimnames_setdiff,
from_contract_labels, inds, name, nameperm, to_contract_labels, to_inds, uniquename
using Mooncake: Mooncake, @zero_derivative, DefaultCtx

Mooncake.tangent_type(::Type{<:NamedUnitRange}) = Mooncake.NoTangent

@zero_derivative DefaultCtx Tuple{typeof(nameperm), Any, Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(to_contract_labels), Any, Any}
@zero_derivative DefaultCtx Tuple{typeof(from_contract_labels), Any, Any, Any}
# `dimnames(::NamedTensor)` returns the stored names `Vector` directly, so its output
# aliases a field, where `@zero_derivative` is documented to be incorrect. Let
# Mooncake differentiate it through the underlying `getfield`, whose built-in rule
Expand Down
36 changes: 33 additions & 3 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,41 @@ using TupleTools: TupleTools
# This layer is used to define derivative rules (to skip differentiating `setdiff`).
dimnames_setdiff(s1, s2) = setdiff(s1, s2)

# Relabel the global dimension names to labels local to this contraction: operand 1's names
# become `1:length(dimnames1)`, and each of operand 2's names reuses operand 1's label where
# they match (the contracted dimensions) and otherwise gets a fresh label
# `length(dimnames1) + position`. Encoding the fresh labels by position lets
# `from_contract_labels` recover the uncontracted operand-2 names by position.
function to_contract_labels(dimnames1, dimnames2)
n1len = length(dimnames1)
labels2 = map(eachindex(dimnames2)) do i2
i1 = findfirst(==(dimnames2[i2]), dimnames1)
return isnothing(i1) ? n1len + i2 : i1
end
return 1:n1len, labels2
end

# Invert `to_contract_labels`: map the result labels back to the global dimension names of the
# two operands by position.
function from_contract_labels(labels, dimnames1, dimnames2)
n1len = length(dimnames1)
return map(
label -> label <= n1len ? dimnames1[label] : dimnames2[label - n1len],
labels
)
end

Base.:*(a1::AbstractNamedTensor, a2::AbstractNamedTensor) = mul_nameddims(a1, a2)
function mul_nameddims(a1::AbstractNamedTensor, a2::AbstractNamedTensor)
a_dest, dimnames_dest = TA.contract(
unnamed(a1), dimnames(a1), unnamed(a2), dimnames(a2)
)
dimnames1, dimnames2 = dimnames(a1), dimnames(a2)
# The contraction structure depends only on the equality pattern of the dimension names,
# so relabel them to integers local to this contraction once and run the contraction-label
# bookkeeping on cheap integers, recovering the result names afterward. This keeps
# `TensorAlgebra`'s `setdiff`/`findfirst` passes off the dimension-name type, which for
# `IndexName` carries an id and a tag dictionary and is comparatively expensive to compare.
labels1, labels2 = to_contract_labels(dimnames1, dimnames2)
a_dest, labels_dest = TA.contract(unnamed(a1), labels1, unnamed(a2), labels2)
dimnames_dest = from_contract_labels(labels_dest, dimnames1, dimnames2)
return nameddims(a_dest, dimnames_dest)
end

Expand Down
Loading