Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.15.1"
version = "0.16.0"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ path = ".."
Documenter = "1.8.1"
ITensorFormatter = "0.2.27"
Literate = "2.20.1"
TensorAlgebra = "0.15"
TensorAlgebra = "0.16"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
path = ".."

[compat]
TensorAlgebra = "0.15"
TensorAlgebra = "0.16"
24 changes: 12 additions & 12 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ for f in (
function $f(style::FusionStyle, A, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
X, Y = MatrixAlgebraKit.$f(A_mat; kwargs...)
axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain)
return unmatricize(style, X, axes_codomain, (axes(X, 2),)),
axes_codomain, axes_domain = bipartition_axes(axes(A), ndims_codomain)
return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)),
unmatricize(style, Y, (axes(Y, 1),), axes_domain)
end
function $f(A, ndims_codomain::Val; kwargs...)
Expand Down Expand Up @@ -209,9 +209,9 @@ for f in (:svd_compact, :svd_full, :svd_trunc)
function $f(style::FusionStyle, A, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
U, S, Vᴴ = MatrixAlgebraKit.$f(A_mat; kwargs...)
axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain)
return unmatricize(style, U, axes_codomain, (axes(U, 2),)),
unmatricize(style, S, (axes(S, 1),), (axes(S, 2),)),
axes_codomain, axes_domain = bipartition_axes(axes(A), ndims_codomain)
return unmatricize(style, U, axes_codomain, (conj(axes(U, 2)),)),
unmatricize(style, S, (axes(S, 1),), (conj(axes(S, 2)),)),
unmatricize(style, Vᴴ, (axes(Vᴴ, 1),), axes_domain)
end
function $f(A, ndims_codomain::Val; kwargs...)
Expand All @@ -228,7 +228,7 @@ for f in (:eigh_full, :eig_full, :eigh_trunc, :eig_trunc)
A_mat = matricize(style, A, ndims_codomain)
D, V = MatrixAlgebraKit.$f(A_mat; kwargs...)
axes_codomain = first(bipartition(axes(A), ndims_codomain))
return D, unmatricize(style, V, axes_codomain, (axes(V, ndims(V)),))
return D, unmatricize(style, V, axes_codomain, (conj(axes(V, ndims(V))),))
end
function $f(A, ndims_codomain::Val; kwargs...)
return $f(FusionStyle(A), A, ndims_codomain; kwargs...)
Expand Down Expand Up @@ -406,7 +406,7 @@ function left_null!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
N = MatrixAlgebraKit.left_null!(A_mat; kwargs...)
axes_codomain = first(bipartition(axes(A), ndims_codomain))
return unmatricize(style, N, axes_codomain, (axes(N, 2),))
return unmatricize(style, N, axes_codomain, (conj(axes(N, 2)),))
end
function left_null!!(A, ndims_codomain::Val; kwargs...)
return left_null!!(FusionStyle(A), A, ndims_codomain; kwargs...)
Expand Down Expand Up @@ -442,7 +442,7 @@ right_null
function right_null!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...)
axes_domain = last(bipartition(axes(A), ndims_codomain))
_, axes_domain = bipartition_axes(axes(A), ndims_codomain)
return unmatricize(style, Nᴴ, (axes(Nᴴ, 1),), axes_domain)
end
function right_null!!(A, ndims_codomain::Val; kwargs...)
Expand Down Expand Up @@ -499,7 +499,7 @@ function gram_eigh_full!!(
A_mat = matricize(style, A, ndims_codomain)
X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...)
axes_codomain = first(bipartition(axes(A), ndims_codomain))
return unmatricize(style, X, axes_codomain, (axes(X, 2),))
return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),))
end
function gram_eigh_full!!(A, ndims_codomain::Val; kwargs...)
return gram_eigh_full!!(FusionStyle(A), A, ndims_codomain; kwargs...)
Expand Down Expand Up @@ -560,8 +560,8 @@ function gram_eigh_full_with_pinv!!(
A_mat = matricize(style, A, ndims_codomain)
X, Y = MatrixAlgebra.gram_eigh_full_with_pinv!!(A_mat; kwargs...)
axes_codomain = first(bipartition(axes(A), ndims_codomain))
return unmatricize(style, X, axes_codomain, (axes(X, 2),)),
unmatricize(style, Y, (axes(Y, 1),), conj.(axes_codomain))
return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)),
unmatricize(style, Y, (axes(Y, 1),), axes_codomain)
end
function gram_eigh_full_with_pinv!!(A, ndims_codomain::Val; kwargs...)
return gram_eigh_full_with_pinv!!(FusionStyle(A), A, ndims_codomain; kwargs...)
Expand Down Expand Up @@ -612,7 +612,7 @@ one
function one!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
MatrixAlgebraKit.one!(A_mat)
codomain_axes, domain_axes = bipartition(axes(A), ndims_codomain)
codomain_axes, domain_axes = bipartition_axes(axes(A), ndims_codomain)
return unmatricize(style, A_mat, codomain_axes, domain_axes)
end
function one!!(A, ndims_codomain::Val; kwargs...)
Expand Down
25 changes: 19 additions & 6 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,28 @@ end

# ==================================== unmatricize =======================================
# Split form: `codomain_axes` and `domain_axes` are the destination axes for the codomain and
# domain groups. This is the primary overload point for new fusion styles. Permutation is
# handled separately by `unmatricizeperm`, so `unmatricize` never has to disambiguate axis
# tuples from permutation tuples regardless of how unconstrained `m` and the axes are.
# domain groups, given codomain-facing (un-dualized), the same convention as `similar_map`. A
# fusion style stores the domain axes dualized, so its overload re-dualizes them with `conj`
# (a no-op on a dense axis). This is the primary overload point for new fusion styles.
# Permutation is handled separately by `unmatricizeperm`, so `unmatricize` never has to
# disambiguate axis tuples from permutation tuples regardless of how unconstrained `m` and the
# axes are.
function unmatricize(style::FusionStyle, m, codomain_axes, domain_axes)
return throw(MethodError(unmatricize, (style, m, codomain_axes, domain_axes)))
end
function unmatricize(m, codomain_axes, domain_axes)
return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes)
end

# Split `axes` into its codomain and domain groups like `bipartition`, but present the domain
# group codomain-facing (un-dualized) with `conj`, the convention `unmatricize` and `similar_map`
# take. The domain axes `bipartition` reads off `axes(a)` are in the stored (dualized) form, so
# this bridges from `axes(a)` to the `unmatricize` axis convention (a no-op on dense axes).
function bipartition_axes(t::Tuple, split...)
codomain_axes, domain_axes = bipartition(t, split...)
return codomain_axes, conj.(domain_axes)
end

# Inverse-bipermutation form: split `axes_dest` into codomain/domain groups reordered by the
# inverse bipermutation, unmatricize in that order, then permute back.
function unmatricizeperm(
Expand All @@ -223,7 +235,7 @@ function unmatricizeperm(
invbiperm = BiTuple(invperm_codomain, invperm_domain)
length(axes_dest) == length(invbiperm) ||
throw(ArgumentError("axes do not match permutation"))
codomain_axes, domain_axes = bipartition(axes_dest, invbiperm)
codomain_axes, domain_axes = bipartition_axes(axes_dest, invbiperm)
a12 = unmatricize(style, m, codomain_axes, domain_axes)
biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes_dest)))
return bipermutedims(a12, biperm_dest)
Expand All @@ -242,7 +254,7 @@ function unmatricizeperm!(
invbiperm = BiTuple(invperm_codomain, invperm_domain)
ndims(a_dest) == length(invbiperm) ||
throw(ArgumentError("destination does not match permutation"))
codomain_axes, domain_axes = bipartition(axes(a_dest), invbiperm)
codomain_axes, domain_axes = bipartition_axes(axes(a_dest), invbiperm)
a_perm = unmatricize(style, m, codomain_axes, domain_axes)
biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes(a_dest))))
return bipermutedims!(a_dest, a_perm, biperm_dest)
Expand All @@ -268,6 +280,7 @@ function matricizekind(
return PermuteMatricizeKind
end
# A dense reshape ignores the codomain/domain split: it just reshapes to the concatenated axes.
# `conj` re-dualizes the codomain-facing `domain_axes` into stored form, a no-op on a dense axis.
function unmatricize(style::ReshapeFusion, m, codomain_axes, domain_axes)
return reshape(m, (codomain_axes..., domain_axes...))
return reshape(m, (codomain_axes..., conj.(domain_axes)...))
end
2 changes: 1 addition & 1 deletion src/matrixfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ for f in MATRIX_FUNCTIONS
function $f(style::FusionStyle, a, ndims_codomain::Val; kwargs...)
a_mat = matricize(style, a, ndims_codomain)
fa_mat = Base.$f(a_mat; kwargs...)
codomain_axes, domain_axes = bipartition(axes(a), ndims_codomain)
codomain_axes, domain_axes = bipartition_axes(axes(a), ndims_codomain)
return unmatricize(style, fa_mat, codomain_axes, domain_axes)
end
function $f(a, ndims_codomain::Val; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Random = "1.10"
SafeTestsets = "0.1"
StableRNGs = "1.0.2"
Suppressor = "0.2"
TensorAlgebra = "0.15"
TensorAlgebra = "0.16"
TensorOperations = "5.1.4"
Test = "1.10"
TestExtras = "0.3.1"
Loading