Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.16.4"
version = "0.16.5"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
2 changes: 1 addition & 1 deletion src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export contract, contract!, eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc
if VERSION >= v"1.11.0-DEV.469"
eval(
Meta.parse(
"public biperm, bipartition, contractopadd!, label_type, matricizeopperm, to_range, zero!, scale!, permuteddims, PermutedDims, conjed, ConjArray"
"public biperm, bipartition, contractopadd!, label_type, matricizeopperm, permutedims, permutedims!, to_range, zero!, scale!, permuteddims, PermutedDims, conjed, ConjArray"
)
)
end
Expand Down
50 changes: 49 additions & 1 deletion src/permutedimsadd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function bipermutedimsopadd!(
check_input(bipermutedimsopadd!, dest, op, src, perm_codomain, perm_domain)

dest′ = SV.StridedView(dest)
src′ = permutedims(SV.StridedView(src), perm)
src′ = Base.permutedims(SV.StridedView(src), perm)
_opadd!(dest′, op, src′, α, β)
return dest
end
Expand Down Expand Up @@ -181,3 +181,51 @@ end
`dest .+= src`.
"""
add!(dest, src) = add!(dest, src, true, true)

# ---------------------------------------------------------------------------- #
# permutedims — out-of-place, optional bipartition
# ---------------------------------------------------------------------------- #

"""
permutedims!(dest, a, perm)
permutedims!(dest, a, perm_codomain, perm_domain)

In-place counterpart of [`permutedims`](@ref): write the permuted `a` into `dest`.
Both forms forward to `bipermutedimsopadd!` with `α, β = true, false`; the flat form
passes an empty domain permutation.
"""
function permutedims!(dest, a, perm)
return bipermutedimsopadd!(dest, identity, a, perm, (), true, false)
end
function permutedims!(dest, a, perm_codomain, perm_domain)
return bipermutedimsopadd!(dest, identity, a, perm_codomain, perm_domain, true, false)
end

"""
permutedims(a, perm)
permutedims(a, perm_codomain, perm_domain)

Out-of-place permutation of `a`, mirroring `TensorKit.permute`. The single-permutation
form reorders every dimension into `perm`, giving an all-codomain result. The
two-permutation form additionally splits the dimensions into a codomain/domain
bipartition, with `perm_codomain` selecting the codomain dimensions and `perm_domain`
the domain ones.

Allocates the destination with [`similar_map`](@ref) and materializes it through
[`permutedims!`](@ref), so any operand implementing the `permutedimsopadd!` /
`bipermutedimsopadd!` interface (a dense array, a graded array, a `TensorMap`) is
permuted with no dedicated method. A dense operand ignores the bipartition and stores
the result flat.
"""
function permutedims(a, perm)
dest = similar_map(a, eltype(a), map(p -> axes(a, p), perm), ())
return permutedims!(dest, a, perm)
end
function permutedims(a, perm_codomain, perm_domain)
codomain_axes = map(p -> axes(a, p), perm_codomain)
# `similar_map` dualizes the domain axes it is given (as it does for `similar_map`
# itself), so pass them pre-conjugated to land back on the source's own axes.
domain_axes = map(p -> conj(axes(a, p)), perm_domain)
dest = similar_map(a, eltype(a), codomain_axes, domain_axes)
return permutedims!(dest, a, perm_codomain, perm_domain)
end
4 changes: 2 additions & 2 deletions test/test_exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ using Test: @test, @testset
exports,
[
:biperm, :bipartition, :contractopadd!, :label_type, :matricizeopperm,
:to_range, :zero!, :scale!, :permuteddims, :PermutedDims, :conjed,
:ConjArray,
:permutedims, :permutedims!, :to_range, :zero!, :scale!, :permuteddims,
:PermutedDims, :conjed, :ConjArray,
]
)
end
Expand Down
27 changes: 25 additions & 2 deletions test/test_permutedimsadd.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Adapt: adapt
using JLArrays: JLArray
using TensorAlgebra: ConjArray, PermutedDims, add!, bipermutedimsopadd!, conjed,
permuteddims, permutedimsadd!, permutedimsopadd!
using TensorAlgebra: TensorAlgebra, ConjArray, PermutedDims, add!, bipermutedimsopadd!,
conjed, permuteddims, permutedimsadd!, permutedimsopadd!
using Test: @test, @testset

# A non-`AbstractArray` operand, to check that `permuteddims` falls back to `PermutedDims`.
Expand Down Expand Up @@ -168,6 +168,29 @@ end
bipermutedimsopadd!(dest, identity, src, (), (), T(3), T(5))
@test dest[] == 3 * 7 + 5 * 2
end
@testset "permutedims / permutedims! (out-of-place, arraytype=$arrayt)" for arrayt in
(
Array,
JLArray,
)
dev = adapt(arrayt)
a = dev(randn(2, 3, 4))
ref = permutedims(a, (3, 1, 2))
# Flat form reorders all dimensions; on a dense array the bipartition form
# ignores the split and stores the result flat in the concatenated order.
@test TensorAlgebra.permutedims(a, (3, 1, 2)) == ref
@test TensorAlgebra.permutedims(a, (3, 1), (2,)) == ref
@test TensorAlgebra.permutedims(a, (), (3, 1, 2)) == ref
dest = dev(zeros(4, 2, 3))
@test TensorAlgebra.permutedims!(dest, a, (3, 1, 2)) === dest
@test dest == ref
dest = dev(zeros(4, 2, 3))
TensorAlgebra.permutedims!(dest, a, (3, 1), (2,))
@test dest == ref
dest = dev(zeros(4, 2, 3))
TensorAlgebra.permutedims!(dest, a, (), (3, 1, 2))
@test dest == ref
end
@testset "permutedimsopadd! (arraytype=$arrayt)" for arrayt in (Array,)
dev = adapt(arrayt)
a = dev(randn(ComplexF64, 2, 2, 2))
Expand Down
42 changes: 40 additions & 2 deletions test/test_tensorkitext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using StableRNGs: StableRNG
using TensorAlgebra: TensorAlgebra, checked_project, contract, matricize, project,
project_map, projectto!, rand_map, randn_map, similar_map, tryflattenlinear,
unmatricize, zeros_map
using TensorKit: @tensor, AbstractTensorMap, Rep, SU₂, TensorMap, U₁, fuse, isomorphism,
randn, space, storagetype, ←, ⊗
using TensorKit: @tensor, AbstractTensorMap, Rep, SU₂, TensorMap, U₁, dual, fuse,
isomorphism, randn, space, storagetype, ←, ⊗
using Test: @test, @test_throws, @testset

# A shared bond contracts when it sits in one operand's domain and the other's codomain, i.e.
Expand Down Expand Up @@ -170,6 +170,44 @@ using Test: @test, @test_throws, @testset
@test norm(project(elt[0, 1], (W,))) == 0
end

# `permutedims` reorders a `TensorMap`'s indices; the flat form gives an all-codomain
# result, and the bipartition form re-expresses the requested codomain/domain split. Both
# ride TensorKit's `permute` through the `bipermutedimsopadd!` interface, no dedicated method.
@testset "permutedims on a TensorMap" begin
A1 = Rep[U₁](0 => 2, 1 => 1)
A2 = Rep[U₁](0 => 1, 1 => 1)
B = Rep[U₁](0 => 1, -1 => 2)
t = randn(rng, elt, A1 ⊗ A2, B)
ref = permutedims(convert(Array, t), (3, 1, 2))

# Flat: all-codomain result whose dense form matches the plain permutation.
tf = TensorAlgebra.permutedims(t, (3, 1, 2))
@test space(tf) == ((dual(B) ⊗ A1 ⊗ A2) ← one(A1))
@test convert(Array, tf) == ref

# Bipartition selecting the original split reproduces the space and data exactly.
tb = TensorAlgebra.permutedims(t, (1, 2), (3,))
@test space(tb) == space(t)
@test convert(Array, tb) == convert(Array, t)

# Repartitioning form: move the domain index into the codomain while reordering.
tr = TensorAlgebra.permutedims(t, (3, 1), (2,))
@test space(tr) == ((dual(B) ⊗ A1) ← dual(A2))
@test convert(Array, tr) == ref

# In-place form writes into a matching destination.
dest = similar_map(t, elt, (dual(B), A1, A2), ())
@test TensorAlgebra.permutedims!(dest, t, (3, 1, 2)) === dest
@test convert(Array, dest) == ref

# Empty codomain: every index lands in the domain, the mirror of the flat all-codomain
# form. The domain space type is read from the operand, so the empty codomain tuple does
# not need to carry it.
te = TensorAlgebra.permutedims(t, (), (3, 1, 2))
@test space(te) == (one(A1) ← (B ⊗ dual(A1) ⊗ dual(A2)))
@test convert(Array, te) == ref
end

# A linear combination of `TensorMap`s flattens to a `LinearBroadcasted` that materializes
# into a `TensorMap` destination via `copyto!`; a nonlinear broadcast has no linear form.
@testset "linear-combination broadcast" begin
Expand Down
Loading