From d753a55d29e4a4aae9feaa8d23ce2fbeab496eaf Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 1 Jul 2026 19:07:39 -0400 Subject: [PATCH 1/2] Take unmatricize domain axes un-dualized (matching similar_map) `unmatricize` now takes its domain axes codomain-facing (un-dualized), the same convention `similar_map` already uses, rather than in the stored dualized form. The two allocation entry points now agree, and a backend receives the actual domain spaces instead of the dualized index spaces. Each fusion style re-dualizes the domain axes with `conj` (a no-op on a dense axis) when it stores them, so dense arrays are unaffected and only graded backends observe the change. This is a breaking change to the `unmatricize` axis convention: a caller reconstructing an array whose domain axis is `dual(r)` now passes `r`. --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- src/factorizations.jl | 24 ++++++++++++------------ src/matricize.jl | 16 ++++++++++------ src/matrixfunctions.jl | 2 +- test/Project.toml | 2 +- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 91ca84e..718d95a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.15.1" +version = "0.16.0" authors = ["ITensor developers and contributors"] [workspace] diff --git a/docs/Project.toml b/docs/Project.toml index 119db36..065b705 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ path = ".." Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -TensorAlgebra = "0.15" +TensorAlgebra = "0.16" diff --git a/examples/Project.toml b/examples/Project.toml index 9139910..480a18f 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" path = ".." [compat] -TensorAlgebra = "0.15" +TensorAlgebra = "0.16" diff --git a/src/factorizations.jl b/src/factorizations.jl index c4ab7a9..37783d4 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -12,8 +12,8 @@ for f in ( A_mat = matricize(style, A, ndims_codomain) X, Y = MatrixAlgebraKit.$f(A_mat; kwargs...) axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain) - return unmatricize(style, X, axes_codomain, (axes(X, 2),)), - unmatricize(style, Y, (axes(Y, 1),), axes_domain) + return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)), + unmatricize(style, Y, (axes(Y, 1),), conj.(axes_domain)) end function $f(A, ndims_codomain::Val; kwargs...) return $f(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -210,9 +210,9 @@ for f in (:svd_compact, :svd_full, :svd_trunc) A_mat = matricize(style, A, ndims_codomain) U, S, Vᴴ = MatrixAlgebraKit.$f(A_mat; kwargs...) axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain) - return unmatricize(style, U, axes_codomain, (axes(U, 2),)), - unmatricize(style, S, (axes(S, 1),), (axes(S, 2),)), - unmatricize(style, Vᴴ, (axes(Vᴴ, 1),), axes_domain) + return unmatricize(style, U, axes_codomain, (conj(axes(U, 2)),)), + unmatricize(style, S, (axes(S, 1),), (conj(axes(S, 2)),)), + unmatricize(style, Vᴴ, (axes(Vᴴ, 1),), conj.(axes_domain)) end function $f(A, ndims_codomain::Val; kwargs...) return $f(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -228,7 +228,7 @@ for f in (:eigh_full, :eig_full, :eigh_trunc, :eig_trunc) A_mat = matricize(style, A, ndims_codomain) D, V = MatrixAlgebraKit.$f(A_mat; kwargs...) axes_codomain = first(bipartition(axes(A), ndims_codomain)) - return D, unmatricize(style, V, axes_codomain, (axes(V, ndims(V)),)) + return D, unmatricize(style, V, axes_codomain, (conj(axes(V, ndims(V))),)) end function $f(A, ndims_codomain::Val; kwargs...) return $f(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -406,7 +406,7 @@ function left_null!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) axes_codomain = first(bipartition(axes(A), ndims_codomain)) - return unmatricize(style, N, axes_codomain, (axes(N, 2),)) + return unmatricize(style, N, axes_codomain, (conj(axes(N, 2)),)) end function left_null!!(A, ndims_codomain::Val; kwargs...) return left_null!!(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -443,7 +443,7 @@ function right_null!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) axes_domain = last(bipartition(axes(A), ndims_codomain)) - return unmatricize(style, Nᴴ, (axes(Nᴴ, 1),), axes_domain) + return unmatricize(style, Nᴴ, (axes(Nᴴ, 1),), conj.(axes_domain)) end function right_null!!(A, ndims_codomain::Val; kwargs...) return right_null!!(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -499,7 +499,7 @@ function gram_eigh_full!!( A_mat = matricize(style, A, ndims_codomain) X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...) axes_codomain = first(bipartition(axes(A), ndims_codomain)) - return unmatricize(style, X, axes_codomain, (axes(X, 2),)) + return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)) end function gram_eigh_full!!(A, ndims_codomain::Val; kwargs...) return gram_eigh_full!!(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -560,8 +560,8 @@ function gram_eigh_full_with_pinv!!( A_mat = matricize(style, A, ndims_codomain) X, Y = MatrixAlgebra.gram_eigh_full_with_pinv!!(A_mat; kwargs...) axes_codomain = first(bipartition(axes(A), ndims_codomain)) - return unmatricize(style, X, axes_codomain, (axes(X, 2),)), - unmatricize(style, Y, (axes(Y, 1),), conj.(axes_codomain)) + return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)), + unmatricize(style, Y, (axes(Y, 1),), axes_codomain) end function gram_eigh_full_with_pinv!!(A, ndims_codomain::Val; kwargs...) return gram_eigh_full_with_pinv!!(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -613,7 +613,7 @@ function one!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) MatrixAlgebraKit.one!(A_mat) codomain_axes, domain_axes = bipartition(axes(A), ndims_codomain) - return unmatricize(style, A_mat, codomain_axes, domain_axes) + return unmatricize(style, A_mat, codomain_axes, conj.(domain_axes)) end function one!!(A, ndims_codomain::Val; kwargs...) return one!!(FusionStyle(A), A, ndims_codomain; kwargs...) diff --git a/src/matricize.jl b/src/matricize.jl index 61d709e..164edd8 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -198,9 +198,12 @@ end # ==================================== unmatricize ======================================= # Split form: `codomain_axes` and `domain_axes` are the destination axes for the codomain and -# domain groups. This is the primary overload point for new fusion styles. Permutation is -# handled separately by `unmatricizeperm`, so `unmatricize` never has to disambiguate axis -# tuples from permutation tuples regardless of how unconstrained `m` and the axes are. +# domain groups, given codomain-facing (un-dualized), the same convention as `similar_map`. A +# fusion style stores the domain axes dualized, so its overload re-dualizes them with `conj` +# (a no-op on a dense axis). This is the primary overload point for new fusion styles. +# Permutation is handled separately by `unmatricizeperm`, so `unmatricize` never has to +# disambiguate axis tuples from permutation tuples regardless of how unconstrained `m` and the +# axes are. function unmatricize(style::FusionStyle, m, codomain_axes, domain_axes) return throw(MethodError(unmatricize, (style, m, codomain_axes, domain_axes))) end @@ -224,7 +227,7 @@ function unmatricizeperm( length(axes_dest) == length(invbiperm) || throw(ArgumentError("axes do not match permutation")) codomain_axes, domain_axes = bipartition(axes_dest, invbiperm) - a12 = unmatricize(style, m, codomain_axes, domain_axes) + a12 = unmatricize(style, m, codomain_axes, conj.(domain_axes)) biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes_dest))) return bipermutedims(a12, biperm_dest) end @@ -243,7 +246,7 @@ function unmatricizeperm!( ndims(a_dest) == length(invbiperm) || throw(ArgumentError("destination does not match permutation")) codomain_axes, domain_axes = bipartition(axes(a_dest), invbiperm) - a_perm = unmatricize(style, m, codomain_axes, domain_axes) + a_perm = unmatricize(style, m, codomain_axes, conj.(domain_axes)) biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes(a_dest)))) return bipermutedims!(a_dest, a_perm, biperm_dest) end @@ -268,6 +271,7 @@ function matricizekind( return PermuteMatricizeKind end # A dense reshape ignores the codomain/domain split: it just reshapes to the concatenated axes. +# `conj` re-dualizes the codomain-facing `domain_axes` into stored form, a no-op on a dense axis. function unmatricize(style::ReshapeFusion, m, codomain_axes, domain_axes) - return reshape(m, (codomain_axes..., domain_axes...)) + return reshape(m, (codomain_axes..., conj.(domain_axes)...)) end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index e0676aa..a26e079 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -37,7 +37,7 @@ for f in MATRIX_FUNCTIONS a_mat = matricize(style, a, ndims_codomain) fa_mat = Base.$f(a_mat; kwargs...) codomain_axes, domain_axes = bipartition(axes(a), ndims_codomain) - return unmatricize(style, fa_mat, codomain_axes, domain_axes) + return unmatricize(style, fa_mat, codomain_axes, conj.(domain_axes)) end function $f(a, ndims_codomain::Val; kwargs...) return $f(FusionStyle(a), a, ndims_codomain; kwargs...) diff --git a/test/Project.toml b/test/Project.toml index a8e405f..963ee89 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.15" +TensorAlgebra = "0.16" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From 6a9338090debe53901ed337d2862bd16d898dcc6 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 1 Jul 2026 19:33:18 -0400 Subject: [PATCH 2/2] Route bipartition domain axes through a bipartition_axes helper ## Summary Adds `bipartition_axes`, a `bipartition` that presents the domain group un-dualized, and routes the factorization, matrix-function, and `unmatricizeperm` callers through it. This keeps the domain `conj` in one place so the `unmatricize` axis convention is applied consistently, rather than repeating `conj.(axes_domain)` at each call site. The single-factor rank-axis reconstructions keep their local `conj`, since no bipartition is involved there. --- src/factorizations.jl | 16 ++++++++-------- src/matricize.jl | 17 +++++++++++++---- src/matrixfunctions.jl | 4 ++-- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index 37783d4..0cf471c 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -11,9 +11,9 @@ for f in ( function $f(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) X, Y = MatrixAlgebraKit.$f(A_mat; kwargs...) - axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain) + axes_codomain, axes_domain = bipartition_axes(axes(A), ndims_codomain) return unmatricize(style, X, axes_codomain, (conj(axes(X, 2)),)), - unmatricize(style, Y, (axes(Y, 1),), conj.(axes_domain)) + unmatricize(style, Y, (axes(Y, 1),), axes_domain) end function $f(A, ndims_codomain::Val; kwargs...) return $f(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -209,10 +209,10 @@ for f in (:svd_compact, :svd_full, :svd_trunc) function $f(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) U, S, Vᴴ = MatrixAlgebraKit.$f(A_mat; kwargs...) - axes_codomain, axes_domain = bipartition(axes(A), ndims_codomain) + axes_codomain, axes_domain = bipartition_axes(axes(A), ndims_codomain) return unmatricize(style, U, axes_codomain, (conj(axes(U, 2)),)), unmatricize(style, S, (axes(S, 1),), (conj(axes(S, 2)),)), - unmatricize(style, Vᴴ, (axes(Vᴴ, 1),), conj.(axes_domain)) + unmatricize(style, Vᴴ, (axes(Vᴴ, 1),), axes_domain) end function $f(A, ndims_codomain::Val; kwargs...) return $f(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -442,8 +442,8 @@ right_null function right_null!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - axes_domain = last(bipartition(axes(A), ndims_codomain)) - return unmatricize(style, Nᴴ, (axes(Nᴴ, 1),), conj.(axes_domain)) + _, axes_domain = bipartition_axes(axes(A), ndims_codomain) + return unmatricize(style, Nᴴ, (axes(Nᴴ, 1),), axes_domain) end function right_null!!(A, ndims_codomain::Val; kwargs...) return right_null!!(FusionStyle(A), A, ndims_codomain; kwargs...) @@ -612,8 +612,8 @@ one function one!!(style::FusionStyle, A, ndims_codomain::Val; kwargs...) A_mat = matricize(style, A, ndims_codomain) MatrixAlgebraKit.one!(A_mat) - codomain_axes, domain_axes = bipartition(axes(A), ndims_codomain) - return unmatricize(style, A_mat, codomain_axes, conj.(domain_axes)) + codomain_axes, domain_axes = bipartition_axes(axes(A), ndims_codomain) + return unmatricize(style, A_mat, codomain_axes, domain_axes) end function one!!(A, ndims_codomain::Val; kwargs...) return one!!(FusionStyle(A), A, ndims_codomain; kwargs...) diff --git a/src/matricize.jl b/src/matricize.jl index 164edd8..ea71bb8 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -211,6 +211,15 @@ function unmatricize(m, codomain_axes, domain_axes) return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes) end +# Split `axes` into its codomain and domain groups like `bipartition`, but present the domain +# group codomain-facing (un-dualized) with `conj`, the convention `unmatricize` and `similar_map` +# take. The domain axes `bipartition` reads off `axes(a)` are in the stored (dualized) form, so +# this bridges from `axes(a)` to the `unmatricize` axis convention (a no-op on dense axes). +function bipartition_axes(t::Tuple, split...) + codomain_axes, domain_axes = bipartition(t, split...) + return codomain_axes, conj.(domain_axes) +end + # Inverse-bipermutation form: split `axes_dest` into codomain/domain groups reordered by the # inverse bipermutation, unmatricize in that order, then permute back. function unmatricizeperm( @@ -226,8 +235,8 @@ function unmatricizeperm( invbiperm = BiTuple(invperm_codomain, invperm_domain) length(axes_dest) == length(invbiperm) || throw(ArgumentError("axes do not match permutation")) - codomain_axes, domain_axes = bipartition(axes_dest, invbiperm) - a12 = unmatricize(style, m, codomain_axes, conj.(domain_axes)) + codomain_axes, domain_axes = bipartition_axes(axes_dest, invbiperm) + a12 = unmatricize(style, m, codomain_axes, domain_axes) biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes_dest))) return bipermutedims(a12, biperm_dest) end @@ -245,8 +254,8 @@ function unmatricizeperm!( invbiperm = BiTuple(invperm_codomain, invperm_domain) ndims(a_dest) == length(invbiperm) || throw(ArgumentError("destination does not match permutation")) - codomain_axes, domain_axes = bipartition(axes(a_dest), invbiperm) - a_perm = unmatricize(style, m, codomain_axes, conj.(domain_axes)) + codomain_axes, domain_axes = bipartition_axes(axes(a_dest), invbiperm) + a_perm = unmatricize(style, m, codomain_axes, domain_axes) biperm_dest = BiTuple(Tuple(invperm(invbiperm)), Val(length_codomain(axes(a_dest)))) return bipermutedims!(a_dest, a_perm, biperm_dest) end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index a26e079..1734a9e 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -36,8 +36,8 @@ for f in MATRIX_FUNCTIONS function $f(style::FusionStyle, a, ndims_codomain::Val; kwargs...) a_mat = matricize(style, a, ndims_codomain) fa_mat = Base.$f(a_mat; kwargs...) - codomain_axes, domain_axes = bipartition(axes(a), ndims_codomain) - return unmatricize(style, fa_mat, codomain_axes, conj.(domain_axes)) + codomain_axes, domain_axes = bipartition_axes(axes(a), ndims_codomain) + return unmatricize(style, fa_mat, codomain_axes, domain_axes) end function $f(a, ndims_codomain::Val; kwargs...) return $f(FusionStyle(a), a, ndims_codomain; kwargs...)