Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/enzyme-indexmanipulations-flip-twist/flip.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Test, TestExtras
using TensorKit
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Index Manipulations (flip):" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T) TA ($TA)" for V in spacelist, T in eltypes, TA in (Duplicated,)
atol = default_tol(T)
rtol = default_tol(T)
has_braiding = BraidingStyle(sectortype(eltype(V))) isa HasBraiding
if has_braiding
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
EnzymeTestUtils.test_reverse(flip, TA, (A, TA), [1, 3]; atol, rtol, fkwargs = (inv = true,))
EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol)
EnzymeTestUtils.test_reverse(flip, TA, (A, TA), ([1, 3], Const); atol, rtol)
end
end
end
24 changes: 24 additions & 0 deletions test/enzyme-indexmanipulations-flip-twist/twist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@timedtestset "Enzyme - Index Manipulations (twist):" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TA in (Duplicated,)
atol = default_tol(T)
rtol = default_tol(T)
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
has_braiding = BraidingStyle(sectortype(eltype(V))) isa HasBraiding
if has_braiding && !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real))
EnzymeTestUtils.test_reverse(twist!, TA, (copy(A), TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
EnzymeTestUtils.test_reverse(twist!, TA, (copy(A), TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,))
EnzymeTestUtils.test_reverse(twist!, TA, (copy(A), TA), (1, Const); atol, rtol)
EnzymeTestUtils.test_reverse(twist!, TA, (copy(A), TA), ([1, 3], Const); atol, rtol)
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ is_ci = get(ENV, "CI", "false") == "true"
Tαs = is_ci ? (Active,) : (Active, Const)
Tβs = is_ci ? (Active,) : (Active, Const)

@timedtestset "Enzyme - Index Manipulations (braid!):" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T) Tα $Tα Tβ $Tβ" for V in spacelist, T in eltypes, Tα in Tαs, Tβ in Tβs
atol = default_tol(T)
rtol = default_tol(T)
Vstr = TensorKit.type_repr(sectortype(eltype(V)))
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
@timedtestset "Enzyme - Index Manipulations (braid!) $(TensorKit.type_repr(sectortype(eltype(V)))) ($T) Tα $Tα Tβ $Tβ" for V in spacelist, T in eltypes, Tα in Tαs, Tβ in Tβs
atol = default_tol(T)
rtol = default_tol(T)
has_braiding = BraidingStyle(sectortype(eltype(V))) isa HasBraiding
if has_braiding
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
α = randn(T)
β = randn(T)
p = randcircshift(numout(A), numin(A))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ is_ci = get(ENV, "CI", "false") == "true"
Tαs = is_ci ? (Active,) : (Active, Const)
Tβs = is_ci ? (Active,) : (Active, Const)

@timedtestset "Enzyme - Index Manipulations (permute!):" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
if VERSION >= v"1.11.0-rc" # segfault issues on 1.10
@timedtestset "Enzyme - Index Manipulations (permute!): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
println(TensorKit.type_repr(sectortype(eltype(V))))
atol = default_tol(T)
rtol = default_tol(T)
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding

symmetricbraiding && @timedtestset "permute! Tα $Tα, Tβ $Tβ" for Tα in Tαs, Tβ in Tβs
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
α = randn(T)
β = randn(T)
p = randindextuple(numind(A))
Expand Down
35 changes: 35 additions & 0 deletions test/enzyme-indexmanipulations-transform/transpose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

is_ci = get(ENV, "CI", "false") == "true"

Tαs = is_ci ? (Active,) : (Active, Const)
Tβs = is_ci ? (Active,) : (Active, Const)

@timedtestset "Enzyme - Index Manipulations (transpose!) $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
α = randn(T)
β = randn(T)

# repeat a couple times to get some distribution of arrows
p = randcircshift(numout(A), numin(A))
C = randn!(transpose(A, p))
!is_ci && EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (copy(C), Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol)
@testset for Tα in Tαs, Tβ in Tβs
EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (copy(C), Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
if !(T <: Real) && !is_ci
EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (copy(C), Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (copy(C), Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
EnzymeTestUtils.test_reverse(TensorKit.transpose!, Duplicated, (copy(C), Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ eltypes = (Float64, ComplexF64)
@timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TA in (Duplicated,)
atol = default_tol(T)
rtol = default_tol(T)
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
@testset for insertunit in (insertleftunit, insertrightunit)
EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol)
EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(4), Const); atol, rtol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ eltypes = (Float64, ComplexF64)
@timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TB in (Duplicated,)
atol = default_tol(T)
rtol = default_tol(T)
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
for i in 1:2
B = insertleftunit(A, i; dual = rand(Bool))
EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = false,))
Expand Down
19 changes: 0 additions & 19 deletions test/enzyme-indexmanipulations/flip.jl

This file was deleted.

37 changes: 0 additions & 37 deletions test/enzyme-indexmanipulations/transpose.jl

This file was deleted.

23 changes: 0 additions & 23 deletions test/enzyme-indexmanipulations/twist.jl

This file was deleted.

Loading