From 306c42d3431f5c77a9e538c5ddaa65a97d4f3223 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sat, 18 Apr 2026 00:47:35 +0200 Subject: [PATCH 01/13] pullback reorganization --- .../MatrixAlgebraKitEnzymeExt.jl | 14 +- .../MatrixAlgebraKitMooncakeExt.jl | 30 +- src/pullbacks/eig.jl | 22 ++ src/pullbacks/eigh.jl | 22 ++ src/pullbacks/lq.jl | 123 +++++-- src/pullbacks/qr.jl | 121 +++++-- src/pullbacks/svd.jl | 335 ++++++++++++------ test/testsuite/ad_utils.jl | 162 +-------- test/testsuite/chainrules.jl | 14 +- 9 files changed, 455 insertions(+), 388 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 85072c949..bcd99243a 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -240,19 +240,7 @@ for f in (:svd_compact!, :svd_full!) USVᴴval = something(cache_USVᴴ, USVᴴ.val) if !isa(A, Const) minmn = min(size(A.val)...) - if $(f == svd_compact!) # compact - svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ) - else # full - # TODO: revisit this once `svd_pullback` supports `svd_full` output and adjoints - U, S, Vᴴ = USVᴴval - vU = view(U, :, 1:minmn) - vS = Diagonal(view(diagview(S), 1:minmn)) - vVᴴ = view(Vᴴ, 1:minmn, :) - vdU = view(dUSVᴴ[1], :, 1:minmn) - vdS = Diagonal(view(diagview(dUSVᴴ[2]), 1:minmn)) - vdVᴴ = view(dUSVᴴ[3], 1:minmn, :) - svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) - end + svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ) end !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) return (nothing, nothing, nothing) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index ef32d6de4..cd3eae8e0 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -418,28 +418,17 @@ for (f!, f) in ( @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) - Ac = copy(A) USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + Ac = copy(A) USVᴴc = copy.(USVᴴ) output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) function svd_adjoint(::NoRData) copy!(A, Ac) - if $(f! == svd_compact!) - svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - else # full - minmn = min(size(A)...) - vU = view(U, :, 1:minmn) - vS = Diagonal(diagview(S)[1:minmn]) - vVᴴ = view(Vᴴ, 1:minmn, :) - vdU = view(dU, :, 1:minmn) - vdS = Diagonal(diagview(dS)[1:minmn]) - vdVᴴ = view(dVᴴ, 1:minmn, :) - svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) - end + svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) copy!(U, USVᴴc[1]) copy!(S, USVᴴc[2]) copy!(Vᴴ, USVᴴc[3]) @@ -448,7 +437,7 @@ for (f!, f) in ( zero!(dVᴴ) return NoRData(), NoRData(), NoRData(), NoRData() end - return CoDual(output, dUSVᴴ), svd_adjoint + return USVᴴ_dUSVᴴ, svd_adjoint end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) @@ -465,18 +454,7 @@ for (f!, f) in ( U, dU = arrayify(U, dU_) S, dS = arrayify(S, dS_) Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) - if $(f == svd_compact) - svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - else # full - minmn = min(size(A)...) - vU = view(U, :, 1:minmn) - vS = Diagonal(view(diagview(S), 1:minmn)) - vVᴴ = view(Vᴴ, 1:minmn, :) - vdU = view(dU, :, 1:minmn) - vdS = Diagonal(view(diagview(dS), 1:minmn)) - vdVᴴ = view(dVᴴ, 1:minmn, :) - svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) - end + svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) zero!(dU) zero!(dS) zero!(dVᴴ) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index a03eb3c4a..7b78121b7 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -201,3 +201,25 @@ function eig_vals_pullback!( ΔDV = (diagonal(ΔD), nothing) return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) end + +""" + remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) + +Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The +eigenvectors are only determined up to a scalar factor (or an abitrary linear transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding components of +`ΔV` are projected out. +""" +function remove_eig_gauge_dependence!( + ΔV, D, V, ind = axes(ΔV, 2); + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) + indV = axes(V, 2)[ind] + Vp = view(V, :, indV) + Ddiag = view(diagview(D), indV) + gaugepart = Vp' * ΔV + gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 + mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1) + return ΔV +end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index db78bd6e7..3b517b977 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -191,3 +191,25 @@ function eigh_vals_pullback!( ΔDV = (diagonal(ΔD), nothing) return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) end + +""" + remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) + +Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix +`V`. The eigenvectors are only determined up to a complex phase (or a unitary transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian +components of `V' * ΔV` are projected out. +""" +function remove_eigh_gauge_dependence!( + ΔV, D, V, ind = axes(ΔV, 2); + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) + indV = axes(V, 2)[ind] + Vp = view(V, :, indV) + Ddiag = view(diagview(D), indV) + gaugepart = project_antihermitian!(Vp' * ΔV) + gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 + mul!(ΔV, Vp, gaugepart, -1, 1) + return ΔV +end diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 1a41c246c..ea088839c 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -1,30 +1,45 @@ lq_rank(L; kwargs...) = qr_rank(L; kwargs...) -function check_lq_cotangents( +function check_and_prepare_lq_cotangents( L, Q, ΔL, ΔQ, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) - minmn = min(size(L, 1), size(Q, 2)) + m, n = size(L, 1), size(Q, 2) + minmn = min(m, n) Δgauge = abs(zero(eltype(Q))) + Q₁ = view(Q, 1:p, :) + ΔQ₁ = zero!(similar(Q₁)) if !iszerotangent(ΔQ) - ΔQ₂ = view(ΔQ, (p + 1):minmn, :) - ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] - Δgauge_Q = norm(ΔQ₂, Inf) - Q₁ = view(Q, 1:p, :) - ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' - mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) - Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) + ΔQ₁ .= view(ΔQ, 1:p, 1:n) + if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ + Q₃ = view(Q, (minmn + 1):size(Q, 1), :) + ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :) + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) + Δgauge_Q = norm(ΔQ₃, Inf) + mul!(ΔQ₁, ΔQ₃Q₁ᴴ', Q₃, -1, 1) + else + ΔQ₂ = view(ΔQ, (p + 1):size(ΔQ, 1), :) + Δgauge_Q = norm(ΔQ₂, Inf) + end Δgauge = max(Δgauge, Δgauge_Q) end if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) - Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf) - Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf)) + size(ΔL) == size(L) || throw(DimensionMismatch("ΔL must have the same size as L")) + ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p)) + ΔL₂₁ = view(ΔL, (p + 1):size(ΔL, 1), 1:p) + ΔL₂₂ = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) + Δgauge_L = norm(view(ΔL₂₂, lowertriangularind(ΔL₂₂)), Inf) + Δgauge_L = max(Δgauge_L, norm(view(ΔL₂₂, diagind(ΔL₂₂)), Inf)) Δgauge = max(Δgauge, Δgauge_L) + else + ΔL₁₁ = nothing + ΔL₂₁ = nothing end Δgauge ≤ gauge_atol || @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return nothing + return ΔL₁₁, ΔL₂₁, ΔQ₁ end """ @@ -53,33 +68,21 @@ function lq_pullback!( L, Q = LQ m = size(L, 1) n = size(Q, 2) - minmn = min(m, n) p = lq_rank(L; rank_atol) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of L*Q ($m, $n)")) - ΔL, ΔQ = ΔLQ - - Q₁ = view(Q, 1:p, :) L₁₁ = LowerTriangular(view(L, 1:p, 1:p)) + L₂₁ = view(L, (p + 1):m, 1:p) + Q₁ = view(Q, 1:p, :) + ΔA₁ = view(ΔA, 1:p, :) ΔA₂ = view(ΔA, (p + 1):m, :) - check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol) + ΔL, ΔQ = ΔLQ + ΔL₁₁, ΔL₂₁, ΔQ₁ = check_and_prepare_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol) - ΔQ̃ = zero!(similar(Q, (p, n))) - if !iszerotangent(ΔQ) - ΔQ₁ = view(ΔQ, 1:p, :) - copy!(ΔQ̃, ΔQ₁) - if minmn < size(Q, 1) - ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) - Q₃ = view(Q, (minmn + 1):size(Q, 1), :) - ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' - ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1) - end - end if !iszerotangent(ΔL) && m > p - L₂₁ = view(L, (p + 1):m, 1:p) - ΔL₂₁ = view(ΔL, (p + 1):m, 1:p) - ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1) + ΔQ₁ = mul!(ΔQ₁, L₂₁' * ΔL₂₁, Q₁, -1, 1) # Adding ΔA₂ contribution ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1) end @@ -87,19 +90,15 @@ function lq_pullback!( # construct M M = zero!(similar(L, (p, p))) if !iszerotangent(ΔL) - ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p)) M = mul!(M, L₁₁', ΔL₁₁, 1, 1) end - M = mul!(M, ΔQ̃, Q₁', -1, 1) + M = mul!(M, ΔQ₁, Q₁', -1, 1) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) Md .= real.(Md) end - ldiv!(L₁₁', M) - ldiv!(L₁₁', ΔQ̃) - ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1) - ΔA₁ .+= ΔQ̃ + ΔA₁ .+= ldiv!(L₁₁', mul!(ΔQ₁, M, Q₁, +1, 1)) return ΔA end @@ -134,3 +133,51 @@ function lq_null_pullback!( end return ΔA end + + +""" + remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) + +Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and +`Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. +Additionally, columns of `ΔL` beyond the rank are zeroed out. +""" +function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) + r = MatrixAlgebraKit.lq_rank(L; rank_atol) + minmn = min(size(A)...) + Q₁ = view(Q, 1:r, :) + ΔQ₂ = view(ΔQ, (r + 1):minmn, :) + ΔQ₂ .= 0 + ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full + if r == minmn + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) + else # rank-deficient case, no gauge-invariant information + ΔQ₃ .= 0 + end + ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) + diagview(ΔL₂₂) .= 0 + view(ΔL₂₂, lowertriangularind(ΔL₂₂)) .= 0 + return ΔL, ΔQ +end + +""" + remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null +space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of +the compact LQ factor `Q₁`. +""" +function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + return mul!(ΔNᴴ, ΔNᴴ * Nᴴ', Nᴴ, -1, 1) +end + +""" + remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) + +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The +null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the +row span of the compact LQ factor `Q₁` of `A`. +""" +remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 70c5aa89a..055198235 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -1,31 +1,47 @@ qr_rank(R; rank_atol = default_pullback_rank_atol(R)) = @something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0 -function check_qr_cotangents( +function check_and_prepare_qr_cotangents( Q, R, ΔQ, ΔR, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) - minmn = min(size(Q, 1), size(R, 2)) + m, n = size(Q, 1), size(R, 2) + minmn = min(m, n) Δgauge = abs(zero(eltype(Q))) + Q₁ = view(Q, :, 1:p) + ΔQ₁ = zero!(similar(Q₁)) if !iszerotangent(ΔQ) - ΔQ₂ = view(ΔQ, :, (p + 1):minmn) - ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full - Δgauge_Q = norm(ΔQ₂, Inf) - Q₁ = view(Q, :, 1:p) - Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ - mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) - Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) + ΔQ₁ .= view(ΔQ, 1:m, 1:p) + if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ + ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full + Q₁ = view(Q, :, 1:minmn) + Q₃ = view(Q, :, (minmn + 1):size(Q, 2)) + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) + Δgauge_Q = norm(ΔQ₃, Inf) + mul!(ΔQ₁, Q₃, Q₁ᴴΔQ₃', -1, 1) + else + ΔQ₂₃ = view(ΔQ, :, (p + 1):size(Q, 2)) + Δgauge_Q = norm(ΔQ₂₃, Inf) + end Δgauge = max(Δgauge, Δgauge_Q) end if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) - Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf) - Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf)) + size(ΔR) == size(R) || throw(DimensionMismatch("ΔR must have the same size as R")) + ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p)) + ΔR₁₂ = view(ΔR, 1:p, (p + 1):n) + ΔR₂₂ = view(ΔR, (p + 1):minmn, (p + 1):n) + Δgauge_R = norm(view(ΔR₂₂, uppertriangularind(ΔR₂₂)), Inf) + Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf)) Δgauge = max(Δgauge, Δgauge_R) + else + ΔR₁₁ = nothing + ΔR₁₂ = nothing end Δgauge ≤ gauge_atol || @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return nothing + return ΔQ₁, ΔR₁₁, ΔR₁₂ end """ @@ -55,34 +71,22 @@ function qr_pullback!( Q, R = QR m = size(Q, 1) n = size(R, 2) - minmn = min(m, n) - Rd = diagview(R) p = qr_rank(R; rank_atol) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of Q*R ($m, $n)")) - ΔQ, ΔR = ΔQR Q₁ = view(Q, :, 1:p) R₁₁ = UpperTriangular(view(R, 1:p, 1:p)) + R₁₂ = view(R, 1:p, (p + 1):n) + ΔA₁ = view(ΔA, :, 1:p) ΔA₂ = view(ΔA, :, (p + 1):n) - check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol) + ΔQ, ΔR = ΔQR + ΔQ₁, ΔR₁₁, ΔR₁₂ = check_and_prepare_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol) - ΔQ̃ = zero!(similar(Q, (m, p))) - if !iszerotangent(ΔQ) - ΔQ₁ = view(ΔQ, :, 1:p) - copy!(ΔQ̃, ΔQ₁) - if minmn < size(Q, 2) - ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full - Q₃ = view(Q, :, (minmn + 1):size(Q, 2)) - Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ - ΔQ̃ = mul!(ΔQ̃, Q₃, Q₁ᴴΔQ₃', -1, 1) - end - end if !iszerotangent(ΔR) && n > p - R₁₂ = view(R, 1:p, (p + 1):n) - ΔR₁₂ = view(ΔR, 1:p, (p + 1):n) - ΔQ̃ = mul!(ΔQ̃, Q₁, ΔR₁₂ * R₁₂', -1, 1) + ΔQ₁ = mul!(ΔQ₁, Q₁, ΔR₁₂ * R₁₂', -1, 1) # Adding ΔA₂ contribution ΔA₂ = mul!(ΔA₂, Q₁, ΔR₁₂, 1, 1) end @@ -90,19 +94,15 @@ function qr_pullback!( # construct M M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) - ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p)) M = mul!(M, ΔR₁₁, R₁₁', 1, 1) end - M = mul!(M, Q₁', ΔQ̃, -1, 1) + M = mul!(M, Q₁', ΔQ₁, -1, 1) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) Md .= real.(Md) end - rdiv!(M, R₁₁') # R₁₁ is upper triangular - rdiv!(ΔQ̃, R₁₁') - ΔA₁ = mul!(ΔA₁, Q₁, M, +1, 1) - ΔA₁ .+= ΔQ̃ + ΔA₁ .+= rdiv!(mul!(ΔQ₁, Q₁, M, +1, 1), R₁₁') return ΔA end @@ -137,3 +137,50 @@ function qr_null_pullback!( end return ΔA end + +""" + remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = ...) + +Remove the gauge-dependent part from the cotangents `ΔQ` and `ΔR` of the QR factors `Q` and +`R`. For the full QR decomposition, the extra columns of `Q` beyond the rank `r` are not +uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this +ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. +""" +function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) + r = MatrixAlgebraKit.qr_rank(R; rank_atol) + minmn = min(size(A)...) + Q₁ = view(Q, :, 1:r) + ΔQ₂ = view(ΔQ, :, (r + 1):minmn) + ΔQ₂ .= 0 + ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + if r == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) + else # rank-deficient case, no gauge-invariant information + ΔQ₃ .= 0 + end + ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) + diagview(ΔR₂₂) .= 0 + view(ΔR₂₂, uppertriangularind(ΔR₂₂)) .= 0 + return ΔQ, ΔR +end + +""" + remove_qr_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the QR null space `N`. The null +space is only determined up to a unitary rotation, so `ΔN` is projected onto the column span +of the compact QR factor `Q₁`. +""" +function remove_qr_null_gauge_dependence!(ΔN, A, N) + return mul!(ΔN, N, N' * ΔN, -1, 1) +end + +""" + remove_left_null_gauge_dependence!(ΔN, A, N) + +Remove the gauge-dependent part from the cotangent `ΔN` of the left null space `N`. The null +space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column +span of the compact QR factor `Q₁` of `A`. +""" +remove_left_null_gauge_dependence!(ΔN, A, N) = remove_qr_null_gauge_dependence!(ΔN, A, N) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 01fdc4f70..1d856600a 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,11 +1,107 @@ svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true) -function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) - mask = abs.(Sr' .- Sr) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) +function check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r::Int, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(S), + gauge_atol::Real = default_pullback_gauge_atol(ΔU, ΔSmat, ΔVᴴ) + ) + + m, n = size(U, 1), size(Vᴴ, 2) + minmn = min(m, n) + + U₁ = view(U, :, 1:r) + V₁ᴴ = view(Vᴴ, 1:r, :) + S₁ = view(S, 1:r) + indU = axes(U, 2)[ind] + indV = axes(Vᴴ, 1)[ind] + indS = axes(S, 1)[ind] + Δgauge = zero(eltype(S)) + + if !iszerotangent(ΔU) + ΔgaugeU = zero(eltype(S)) + m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) + length(indU) == size(ΔU, 2) || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) + if indU == 1:r + ΔU₁ = copy(ΔU) + else + ΔU₁ = zero(U₁) + wtmp = similar(U₁, (r,)) + utmp = similar(U₁, (m,)) + for (j, i) in enumerate(indU) + if i <= r + ΔU₁[:, i] .= view(ΔU, :, j) + elseif r == minmn # full rank case, ΔU₃ contains gauge-invariant information along U₁ + mul!(wtmp, U₁', view(ΔU, :, j)) + mul!(ΔU₁, view(U, :, i), wtmp', -1, 1) + utmp .= view(ΔU, :, j) + mul!(utmp, U₁, wtmp, -1, 1) + ΔgaugeU = max(ΔgaugeU, norm(utmp)) + else # remaining columns should be zero + ΔgaugeU = max(ΔgaugeU, norm(view(ΔU, :, j), Inf)) + end + end + end + UᴴΔU₁ = U₁' * ΔU₁ + ΔU₁ = mul!(ΔU₁, U₁, UᴴΔU₁, -1, 1) + aUᴴΔU₁ = project_antihermitian!(UᴴΔU₁) + Δgauge = max(Δgauge, ΔgaugeU) + else + ΔU₁ = nothing + aUᴴΔU₁ = zero!(similar(U₁, (r, r))) + end + if !iszerotangent(ΔVᴴ) + ΔgaugeV = zero(eltype(S)) + n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) + length(indV) == size(ΔVᴴ, 1) || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) + if indV == 1:r + ΔV₁ᴴ = copy(ΔVᴴ) + else + ΔV₁ᴴ = zero(V₁ᴴ) + wtmp = similar(V₁ᴴ, (1, r)) + vtmp = similar(V₁ᴴ, (1, n)) + for (j, i) in enumerate(indV) + if i <= r + ΔV₁ᴴ[i, :] .= view(ΔVᴴ, j, :) + elseif r == minmn # full rank case, ΔV₃ contains gauge-invariant information along Vᴴ₁ + mul!(wtmp, view(ΔVᴴ, j:j, :), V₁ᴴ') + mul!(ΔV₁ᴴ, wtmp', view(Vᴴ, i:i, :), -1, 1) + vtmp .= view(ΔVᴴ, j:j, :) + mul!(vtmp, wtmp, V₁ᴴ, -1, 1) + ΔgaugeV = max(ΔgaugeV, norm(vtmp)) + else # remaining rows should be zero + ΔgaugeV = max(ΔgaugeV, norm(view(ΔVᴴ, j, :), Inf)) + end + end + end + VᴴΔV₁ = V₁ᴴ * ΔV₁ᴴ' + ΔV₁ᴴ = mul!(ΔV₁ᴴ, VᴴΔV₁', V₁ᴴ, -1, 1) + aVᴴΔV₁ = project_antihermitian!(VᴴΔV₁) + Δgauge = max(Δgauge, ΔgaugeV) + else + ΔV₁ᴴ = nothing + aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) + end + mask = abs.(S₁' .- S₁) .< degeneracy_atol + Δgauge = max(Δgauge, norm(view(aUᴴΔU₁, mask) + view(aVᴴΔV₁, mask), Inf)) + + if !iszerotangent(ΔSmat) + ΔS = diagview(ΔSmat) + length(indS) == length(ΔS) || throw(DimensionMismatch("length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))")) + ΔS₁ = zero(S₁) + for (j, i) in enumerate(indS) + if i <= r + ΔS₁[i] = real(ΔS[j]) + else + Δgauge = max(Δgauge, abs(ΔS[j])) + end + end + else + ΔS₁ = nothing + end + Δgauge ≤ gauge_atol || @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + return ΔU₁, ΔS₁, ΔV₁ᴴ, aUᴴΔU₁, aVᴴΔV₁ end """ @@ -13,7 +109,7 @@ end ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...) ) Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd_compact` or @@ -22,9 +118,9 @@ Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with magnitude less than `rank_atol` are missing from `S`. For the cotangents, an arbitrary number of singular vectors or singular values can be missing, i.e. for a matrix `A` with -size `(m, n)`, `ΔU` and `ΔVᴴ` can have sizes `(m, pU)` and `(pV, n)` respectively, whereas -`diagview(ΔS)` can have length `pS`. In those cases, additionally `ind` is required to -specify which singular vectors and values are present in `ΔU`, `ΔS` and `ΔVᴴ`. +size `(m, n)`, `ΔU` and `ΔVᴴ` can have sizes `(m, p)` and `(p, n)` respectively, whereas +`diagview(ΔS)` can have length `p`. In those cases, an additional list `ind` of length `p` +is required to specify which singular vectors and values are present in `ΔU`, `ΔS` and `ΔVᴴ`. A warning will be printed if the cotangents are not gauge-invariant, i.e. if the anti-hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'`, restricted to rows `i` and columns `j` for @@ -34,75 +130,38 @@ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...) ) # Extract the SVD components U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) minmn = min(m, n) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) S = diagview(Smat) - length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) r = svd_rank(S; rank_atol) - Ur = view(U, :, 1:r) - Vᴴr = view(Vᴴ, 1:r, :) - Sr = view(S, 1:r) - - # Extract and check the cotangents - ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UΔU = fill!(similar(U, (r, r)), 0) - VΔV = fill!(similar(Vᴴ, (r, r)), 0) - if !iszerotangent(ΔU) - m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) - pU = size(ΔU, 2) - pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)")) - indU = axes(U, 2)[ind] - length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) - UΔUp = view(UΔU, :, indU) - mul!(UΔUp, Ur', ΔU) - # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU - ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1) - end - if !iszerotangent(ΔVᴴ) - n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) - pV = size(ΔVᴴ, 1) - pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)")) - indV = axes(Vᴴ, 1)[ind] - length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) - VΔVp = view(VΔV, :, indV) - mul!(VΔVp, Vᴴr, ΔVᴴ') - # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ - ΔVᴴ = mul!(copy(ΔVᴴ), VΔVp', Vᴴr, -1, 1) - end - # Project onto antihermitian part; hermitian part outside of Grassmann tangent space - aUΔU = project_antihermitian!(UΔU) - aVΔV = project_antihermitian!(VΔV) + U₁ = view(U, :, 1:r) + V₁ᴴ = view(Vᴴ, 1:r, :) + S₁ = view(S, 1:r) - # check whether cotangents arise from gauge-invariance objective function - check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol, gauge_atol) + ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ + ΔU₁, ΔS₁, ΔV₁ᴴ, aUᴴΔU₁, aVᴴΔV₁ = check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol + ) - UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ - (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) - if !iszerotangent(ΔSmat) - ΔS = diagview(ΔSmat) - pS = length(ΔS) - indS = axes(S, 1)[ind] - length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) - view(diagview(UdΔAV), indS) .+= real.(ΔS) + UdΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ + (aUᴴΔU₁ .- aVᴴΔV₁) .* inv_safe.(S₁' .+ S₁, degeneracy_atol) + if !iszerotangent(ΔS₁) + diagview(UdΔAV) .+= real.(ΔS₁) end - ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA + ΔA = mul!(ΔA, U₁, UdΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA # Add the remaining contributions - if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur - Sp = view(S, indU) - Vᴴp = view(Vᴴ, indU, :) - ΔA = mul!(ΔA, ΔU ./ Sp', Vᴴp, 1, 1) + if m > r && !iszerotangent(ΔU₁) # ΔU₁ is already orthogonal to U₁ + ΔA = mul!(ΔA, ΔU₁ ./ S₁', V₁ᴴ, 1, 1) end - if n > r && !iszerotangent(ΔVᴴ) # remaining ΔV is already orthogonal to Vᴴr - Sp = view(S, indV) - Up = view(U, :, indV) - ΔA = mul!(ΔA, Up, Sp .\ ΔVᴴ, 1, 1) + if n > r && !iszerotangent(ΔV₁ᴴ) # ΔV₁ᴴ is already orthogonal to V₁ᴴ + ΔA = mul!(ΔA, U₁, S₁ .\ ΔV₁ᴴ, 1, 1) end return ΔA end @@ -110,7 +169,7 @@ function svd_pullback!( ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...) ) ΔA_full = zero!(similar(ΔA, size(ΔA))) ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol) @@ -123,7 +182,7 @@ end ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...) ) Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the @@ -134,7 +193,7 @@ rectangular matrices of left and right singular vectors, and `S` diagonal. For t cotangents, it is assumed that if `ΔU` and `ΔVᴴ` are not zero, then they have the same size as `U` and `Vᴴ` (respectively), and if `ΔS` is not zero, then it is a diagonal matrix of the same size as `S`. For this method to work correctly, it is also assumed that the remaining -singular values (not included in `S`) are (sufficiently) separated from those in `S`. +singular values (not included in `S`) are (sufficiently) smaller than those in `S`. A warning will be printed if the cotangents are not gauge-invariant, i.e. if the anti-hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'`, restricted to rows `i` and columns `j` for @@ -144,7 +203,8 @@ function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), + maxiter::Int = 1000, ) # Extract the SVD components @@ -158,62 +218,54 @@ function svd_trunc_pullback!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UΔU = fill!(similar(U, (p, p)), 0) - VΔV = fill!(similar(Vᴴ, (p, p)), 0) - if !iszerotangent(ΔU) - (m, p) == size(ΔU) || throw(DimensionMismatch()) - mul!(UΔU, U', ΔU) - end - if !iszerotangent(ΔVᴴ) - (p, n) == size(ΔVᴴ) || throw(DimensionMismatch()) - mul!(VΔV, Vᴴ, ΔVᴴ') - # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ - ΔVᴴ = mul!(copy(ΔVᴴ), VΔV', Vᴴ, -1, 1) - end - - # Project onto antihermitian part; hermitian part outside of Grassmann tangent space - aUΔU = project_antihermitian!(UΔU) - aVΔV = project_antihermitian!(VΔV) - - # check whether cotangents arise from gauge-invariance objective function - check_svd_cotangents(aUΔU, S, aVΔV; degeneracy_atol, gauge_atol) + ΔU, ΔS, ΔVᴴ, aUᴴΔU, aVᴴΔV = check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol + ) - UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ - (aUΔU .- aVΔV) .* inv_safe.(S' .+ S, degeneracy_atol) - if !iszerotangent(ΔSmat) - ΔS = diagview(ΔSmat) - p == length(ΔS) || throw(DimensionMismatch()) + # This part is the same as in `svd_pullback!` + UdΔAV = (aUᴴΔU .+ aVᴴΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ + (aUᴴΔU .- aVᴴΔV) .* inv_safe.(S' .+ S, degeneracy_atol) + if !iszerotangent(ΔS) diagview(UdΔAV) .+= real.(ΔS) end ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA - # add contribution from orthogonal complement - Ũ = qr_null(U) - Ṽᴴ = lq_null(Vᴴ) - m̃ = m - p - ñ = n - p - Ã = Ũ' * A * Ṽᴴ' - ÃÃ = similar(A, (m̃ + ñ, m̃ + ñ)) - fill!(ÃÃ, 0) - view(ÃÃ, (1:m̃), m̃ .+ (1:ñ)) .= Ã - view(ÃÃ, m̃ .+ (1:ñ), 1:m̃) .= Ã' - - rhs = similar(Ũ, (m̃ + ñ, p)) - if !iszerotangent(ΔU) - mul!(view(rhs, 1:m̃, :), Ũ', ΔU) - else - fill!(view(rhs, 1:m̃, :), 0) - end - if !iszerotangent(ΔVᴴ) - mul!(view(rhs, m̃ .+ (1:ñ), :), Ṽᴴ, ΔVᴴ') - else - fill!(view(rhs, m̃ .+ (1:ñ), :), 0) + # The contribtutions from the orthogonal complement need to be treated differently + # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ + if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) + US = U * Smat + APAᴴ = mul!(A * A', US, US', -1, 1) + SVᴴ = Smat * Vᴴ + AᴴPA = mul!(A' * A, SVᴴ', SVᴴ, -1, 1) + + rhs = [iszerotangent(ΔU) ? zero(U) : ΔU; iszerotangent(ΔVᴴ) ? zero(Vᴴ') : ΔVᴴ'] + AA = [zero(APAᴴ) (A - U * (U' * A)); (A' - Vᴴ' * (Vᴴ * A')) zero(AᴴPA)] + XY = _sylvester(AA, -Smat, rhs) + + Aperp = A - U * Smat * Vᴴ + x₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) + y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) + X = copy(x₀) + Yᴴ = copy(y₀ᴴ) + xₖ, xₖ₊₁ = x₀, zero(x₀) + yₖᴴ, yₖ₊₁ᴴ = y₀ᴴ, zero(y₀ᴴ) + for k in 1:maxiter + xₖ₊₁ = rdiv!(mul!(xₖ₊₁, Aperp, yₖᴴ'), Diagonal(S)) + yₖ₊₁ᴴ = ldiv!(Diagonal(S), mul!(yₖ₊₁ᴴ, xₖ', Aperp)) + X .+= xₖ₊₁ + Yᴴ .+= yₖ₊₁ᴴ + if norm(xₖ₊₁, Inf) < degeneracy_atol && norm(yₖ₊₁ᴴ, Inf) < degeneracy_atol + break + end + xₖ, xₖ₊₁ = xₖ₊₁, xₖ + yₖᴴ, yₖ₊₁ᴴ = yₖ₊₁ᴴ, yₖᴴ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norms: (x: $(norm(xₖ₊₁, Inf)), y: $(norm(yₖ₊₁ᴴ, Inf)))" + end + end + ΔA = mul!(ΔA, X, Vᴴ, 1, 1) + ΔA = mul!(ΔA, U, Yᴴ, 1, 1) end - XY = _sylvester(ÃÃ, -Smat, rhs) - X = view(XY, 1:m̃, :) - Y = view(XY, m̃ .+ (1:ñ), :) - ΔA = mul!(ΔA, Ũ, X * Vᴴ, 1, 1) - ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1) return ΔA end function svd_trunc_pullback!( @@ -253,3 +305,52 @@ function svd_vals_pullback!( ΔUSVᴴ = (nothing, diagonal(ΔS), nothing) return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) end + +""" + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = ..., rank_atol = ...) + +Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The +singular vectors are only determined up to a common complex phase per singular value (or a +unitary transformation across singular vectors associated with degenerate singular values), +so the corresponding anti-Hermitian components of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. +For the full SVD, the extra columns of `U` and rows of `Vᴴ` beyond the rank `r` are +additionally zeroed out, where `r = count(diagview(S) .> rank_atol)`. +""" +function remove_svd_gauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S), + rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S) + ) + Sdiag = diagview(S) + r = MatrixAlgebraKit.svd_rank(Sdiag; rank_atol) + U₁ = view(U, :, 1:r) + Vᴴ₁ = view(Vᴴ, 1:r, :) + ΔU₁ = view(ΔU, :, 1:r) + ΔVᴴ₁ = view(ΔVᴴ, 1:r, :) + Sdiag = diagview(S) + gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 + mul!(ΔU₁, U₁, gaugepart, -1, 1) + if size(ΔU, 2) > r + if r < length(Sdiag) # rank-deficient case, no stable information can be extracted from extra columns of U + ΔU[:, (r + 1):end] .= 0 + else # the component of ΔU₂ along U₁ contains gauge-invariant information + p = size(ΔU, 2) + ΔU₂ = view(ΔU, :, (r + 1):p) + U₁ᴴΔU₂ = U₁' * ΔU₂ + mul!(ΔU₂, U₁, U₁ᴴΔU₂) + end + end + if size(ΔVᴴ, 1) > r + if r < length(Sdiag) # rank-deficient case, no stable information can be extracted from extra rows of Vᴴ + ΔVᴴ[(r + 1):end, :] .= 0 + else # the component of ΔVᴴ₂ along Vᴴ₁ contains gauge-invariant information + p = size(ΔVᴴ, 1) + ΔVᴴ₂ = view(ΔVᴴ, (r + 1):p, :) + ΔVᴴ₂V₁ = ΔVᴴ₂ * Vᴴ₁' + mul!(ΔVᴴ₂, ΔVᴴ₂V₁, Vᴴ₁) + end + end + return ΔU, ΔVᴴ +end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 32fc17485..09558e0ed 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,162 +1,12 @@ +using MatrixAlgebraKit: remove_svd_gauge_dependence!, + remove_eig_gauge_dependence!, remove_eigh_gauge_dependence!, + remove_qr_gauge_dependence!, remove_qr_null_gauge_dependence!, + remove_lq_gauge_dependence!, remove_lq_null_gauge_dependence!, + remove_left_null_gauge_dependence!, remove_right_null_gauge_dependence! + structured_randn!(A::AbstractMatrix) = randn!(A) structured_randn!(A::Diagonal) = (randn!(diagview(A)); return A) -""" - remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) - -Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The -eigenvectors are only determined up to a scalar factor (or an abitrary linear transformation -across eigenvectors associated with degenerate eigenvalues), so the corresponding components of -`ΔV` are projected out. -""" -function remove_eig_gauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end - -""" - remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) - -Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix -`V`. The eigenvectors are only determined up to a complex phase (or a unitary transformation -across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian -components of `V' * ΔV` are projected out. -""" -function remove_eigh_gauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - -""" - remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = ..., rank_atol = ...) - -Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The -singular vectors are only determined up to a common complex phase per singular value (or a -unitary transformation across singular vectors associated with degenerate singular values), -so the corresponding anti-Hermitian components of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. -For the full SVD, the extra columns of `U` and rows of `Vᴴ` beyond the rank `r` are -additionally zeroed out, where `r = count(diagview(S) .> rank_atol)`. -""" -function remove_svd_gauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S), - rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S) - ) - r = MatrixAlgebraKit.svd_rank(diagview(S); rank_atol) - U₁ = view(U, :, 1:r) - Vᴴ₁ = view(Vᴴ, 1:r, :) - ΔU₁ = view(ΔU, :, 1:r) - ΔVᴴ₁ = view(ΔVᴴ, 1:r, :) - Sdiag = diagview(S) - gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 - mul!(ΔU₁, U₁, gaugepart, -1, 1) - ΔU[:, (r + 1):end] .= 0 - ΔVᴴ[(r + 1):end, :] .= 0 - return ΔU, ΔVᴴ -end - -""" - remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = ...) - -Remove the gauge-dependent part from the cotangents `ΔQ` and `ΔR` of the QR factors `Q` and -`R`. For the full QR decomposition, the extra columns of `Q` beyond the rank `r` are not -uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this -ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. -""" -function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) - r = MatrixAlgebraKit.qr_rank(R; rank_atol) - minmn = min(size(A)...) - Q₁ = view(Q, :, 1:r) - ΔQ₂ = view(ΔQ, :, (r + 1):minmn) - ΔQ₂ .= 0 - ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full - Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ - mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) - ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) - MatrixAlgebraKit.diagview(ΔR₂₂) .= 0 - view(ΔR₂₂, MatrixAlgebraKit.uppertriangularind(ΔR₂₂)) .= 0 - return ΔQ, ΔR -end - -""" - remove_qr_null_gauge_dependence!(ΔN, A, N) - -Remove the gauge-dependent part from the cotangent `ΔN` of the QR null space `N`. The null -space is only determined up to a unitary rotation, so `ΔN` is projected onto the column span -of the compact QR factor `Q₁`. -""" -function remove_qr_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - return mul!(ΔN, Q, Q' * ΔN) -end - -""" - remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) - -Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and -`Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely -determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. -Additionally, columns of `ΔL` beyond the rank are zeroed out. -""" -function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) - r = MatrixAlgebraKit.lq_rank(L; rank_atol) - minmn = min(size(A)...) - Q₁ = view(Q, 1:r, :) - ΔQ₂ = view(ΔQ, (r + 1):minmn, :) - ΔQ₂ .= 0 - ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full - ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' - mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) - ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) - MatrixAlgebraKit.diagview(ΔL₂₂) .= 0 - view(ΔL₂₂, MatrixAlgebraKit.lowertriangularind(ΔL₂₂)) .= 0 - return ΔL, ΔQ -end - -""" - remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - -Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null -space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of -the compact LQ factor `Q₁`. -""" -function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - ΔNᴴQᴴ = ΔNᴴ * Q' - return mul!(ΔNᴴ, ΔNᴴQᴴ, Q) -end - -""" - remove_left_null_gauge_dependence!(ΔN, A, N) - -Remove the gauge-dependent part from the cotangent `ΔN` of the left null space `N`. The null -space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column -span of the compact QR factor `Q₁` of `A`. -""" -remove_left_null_gauge_dependence!(ΔN, A, N) = remove_qr_null_gauge_dependence!(ΔN, A, N) - -""" - remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - -Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The -null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the -row span of the compact LQ factor `Q₁` of `A`. -""" -remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - """ call_and_zero!(f!, A, alg) diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index c722fdb67..558afc839 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -8,7 +8,7 @@ for f in :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :eig_trunc_no_error, :eigh_trunc_no_error, - :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, + :svd_compact, :svd_full, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:cr_copy_, f) @@ -418,6 +418,18 @@ function test_chainrules_svd( rrule_f = rrule_via_ad, check_inferred = false ) end + @testset "svd_full" begin + USV, ΔUSVᴴ = ad_svd_full_setup(A) + test_rrule( + cr_copy_svd_full, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, svd_full, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + end @testset "svd_vals" begin S, ΔS = ad_svd_vals_setup(A) test_rrule( From 1c64fd7ca491eb16e61545773115f9002c7a2681 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sat, 18 Apr 2026 22:48:31 +0200 Subject: [PATCH 02/13] add quadratic svd_trunc_pullback --- src/pullbacks/svd.jl | 82 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 1d856600a..77edcaa78 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -233,15 +233,6 @@ function svd_trunc_pullback!( # The contribtutions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) - US = U * Smat - APAᴴ = mul!(A * A', US, US', -1, 1) - SVᴴ = Smat * Vᴴ - AᴴPA = mul!(A' * A, SVᴴ', SVᴴ, -1, 1) - - rhs = [iszerotangent(ΔU) ? zero(U) : ΔU; iszerotangent(ΔVᴴ) ? zero(Vᴴ') : ΔVᴴ'] - AA = [zero(APAᴴ) (A - U * (U' * A)); (A' - Vᴴ' * (Vᴴ * A')) zero(AᴴPA)] - XY = _sylvester(AA, -Smat, rhs) - Aperp = A - U * Smat * Vᴴ x₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) @@ -268,6 +259,79 @@ function svd_trunc_pullback!( end return ΔA end +function svd_trunc_pullback2!( + ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; + rank_atol::Real = 0, + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), + maxiter::Int = 1000, + ) + + # Extract the SVD components + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch()) + p = size(U, 2) + p == size(Vᴴ, 1) || throw(DimensionMismatch()) + S = diagview(Smat) + p == length(S) || throw(DimensionMismatch()) + + # Extract and check the cotangents + ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ + ΔU, ΔS, ΔVᴴ, aUᴴΔU, aVᴴΔV = check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol + ) + + # This part is the same as in `svd_pullback!` + UdΔAV = (aUᴴΔU .+ aVᴴΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ + (aUᴴΔU .- aVᴴΔV) .* inv_safe.(S' .+ S, degeneracy_atol) + if !iszerotangent(ΔS) + diagview(UdΔAV) .+= real.(ΔS) + end + ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA + + # The contribtutions from the orthogonal complement need to be treated differently + # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ + if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) + X₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) + Y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) + AP = A - U * Smat * Vᴴ + AP ./= S[1] + S = S ./ S[1] + X₁ = X₀ + rdiv!(AP * Y₀ᴴ', Diagonal(S)) + Y₁ᴴ = Y₀ᴴ + ldiv!(Diagonal(S), X₀' * AP) + Xₖ, Xₖ₊₁ = X₁, X₀ + Yₖᴴ, Yₖ₊₁ᴴ = Y₁ᴴ, Y₀ᴴ + APAᴴₖ, AᴴPAₖ = AP * AP', AP' * AP + APAᴴₖ₊₁, AᴴPAₖ₊₁ = zero(APAᴴₖ), zero(AᴴPAₖ) + Sₖ, Sₖ₊₁ = S .^ 2, S + for k in 1:maxiter + Xₖ₊₁ = rdiv!(mul!(Xₖ₊₁, APAᴴₖ, Xₖ), Diagonal(Sₖ)) + Yₖ₊₁ᴴ = ldiv!(Diagonal(Sₖ), mul!(Yₖ₊₁ᴴ, Yₖᴴ, AᴴPAₖ)) + if norm(Xₖ₊₁, Inf) < degeneracy_atol && norm(Yₖ₊₁ᴴ, Inf) < degeneracy_atol + break + end + Xₖ₊₁ .+= Xₖ + Yₖ₊₁ᴴ .+= Yₖᴴ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norms: (X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" + break + end + Sₖ₊₁ .= Sₖ .^ 2 + APAᴴₖ₊₁ = mul!(APAᴴₖ₊₁, APAᴴₖ, APAᴴₖ) + AᴴPAₖ₊₁ = mul!(AᴴPAₖ₊₁, AᴴPAₖ, AᴴPAₖ) + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ + Yₖᴴ, Yₖ₊₁ᴴ = Yₖ₊₁ᴴ, Yₖᴴ + APAᴴₖ, APAᴴₖ₊₁ = APAᴴₖ₊₁, APAᴴₖ + AᴴPAₖ, AᴴPAₖ₊₁ = AᴴPAₖ₊₁, AᴴPAₖ + Sₖ, Sₖ₊₁ = Sₖ₊₁, Sₖ + end + ΔA = mul!(ΔA, Xₖ, Vᴴ, 1, 1) + ΔA = mul!(ΔA, U, Yₖᴴ, 1, 1) + end + return ΔA +end + function svd_trunc_pullback!( ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, From b0711774a35cd73ef68cfbefa5944099841378b9 Mon Sep 17 00:00:00 2001 From: Jutho Date: Mon, 20 Apr 2026 11:22:57 +0200 Subject: [PATCH 03/13] Apply suggestions from code review Co-authored-by: Lukas Devos Co-authored-by: Jutho --- src/pullbacks/lq.jl | 8 ++++---- src/pullbacks/qr.jl | 8 ++++---- src/pullbacks/svd.jl | 8 +++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index ea088839c..251fafa86 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -148,17 +148,17 @@ function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebr minmn = min(size(A)...) Q₁ = view(Q, 1:r, :) ΔQ₂ = view(ΔQ, (r + 1):minmn, :) - ΔQ₂ .= 0 + zero!(ΔQ₂) ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full if r == minmn ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) else # rank-deficient case, no gauge-invariant information - ΔQ₃ .= 0 + zero!(ΔQ₃) end ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) - diagview(ΔL₂₂) .= 0 - view(ΔL₂₂, lowertriangularind(ΔL₂₂)) .= 0 + zero!(diagview(ΔL₂₂)) + zero!(view(ΔL₂₂, lowertriangularind(ΔL₂₂))) return ΔL, ΔQ end diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 055198235..97bd45ae6 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -151,17 +151,17 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr minmn = min(size(A)...) Q₁ = view(Q, :, 1:r) ΔQ₂ = view(ΔQ, :, (r + 1):minmn) - ΔQ₂ .= 0 + zero!(ΔQ₂) ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full if r == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) else # rank-deficient case, no gauge-invariant information - ΔQ₃ .= 0 + zero!(ΔQ₃) end ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) - diagview(ΔR₂₂) .= 0 - view(ΔR₂₂, uppertriangularind(ΔR₂₂)) .= 0 + zero!(diagview(ΔR₂₂)) + zero!(view(ΔR₂₂, uppertriangularind(ΔR₂₂))) return ΔQ, ΔR end diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 77edcaa78..f85668d9a 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -158,10 +158,12 @@ function svd_pullback!( # Add the remaining contributions if m > r && !iszerotangent(ΔU₁) # ΔU₁ is already orthogonal to U₁ - ΔA = mul!(ΔA, ΔU₁ ./ S₁', V₁ᴴ, 1, 1) + ΔU₁ ./= S₁' + ΔA = mul!(ΔA, ΔU₁, V₁ᴴ, 1, 1) end if n > r && !iszerotangent(ΔV₁ᴴ) # ΔV₁ᴴ is already orthogonal to V₁ᴴ - ΔA = mul!(ΔA, U₁, S₁ .\ ΔV₁ᴴ, 1, 1) + ΔV₁ᴴ .= S₁ .\ ΔV₁ᴴ + ΔA = mul!(ΔA, U₁, ΔV₁ᴴ, 1, 1) end return ΔA end @@ -230,7 +232,7 @@ function svd_trunc_pullback!( end ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA - # The contribtutions from the orthogonal complement need to be treated differently + # The contributions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) Aperp = A - U * Smat * Vᴴ From 44c63d744642085ae36e563ddf7931bb15143082 Mon Sep 17 00:00:00 2001 From: Jutho Date: Mon, 20 Apr 2026 15:09:33 +0200 Subject: [PATCH 04/13] Apply more suggestions from code review Co-authored-by: Lukas Devos --- src/pullbacks/svd.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index f85668d9a..615509e15 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -235,7 +235,7 @@ function svd_trunc_pullback!( # The contributions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) - Aperp = A - U * Smat * Vᴴ + Aperp = mul!(copy(A), U, Smat * Vᴴ, -1, 1) x₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) X = copy(x₀) @@ -297,11 +297,13 @@ function svd_trunc_pullback2!( if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) X₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) Y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) - AP = A - U * Smat * Vᴴ + AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) AP ./= S[1] S = S ./ S[1] - X₁ = X₀ + rdiv!(AP * Y₀ᴴ', Diagonal(S)) - Y₁ᴴ = Y₀ᴴ + ldiv!(Diagonal(S), X₀' * AP) + X₁ = rdiv!(AP * Y₀ᴴ', Diagonal(S)) + X₁ .+= X₀ + Y₁ᴴ = ldiv!(Diagonal(S), X₀' * AP) + Y₁ᴴ .+= Y₀ᴴ Xₖ, Xₖ₊₁ = X₁, X₀ Yₖᴴ, Yₖ₊₁ᴴ = Y₁ᴴ, Y₀ᴴ APAᴴₖ, AᴴPAₖ = AP * AP', AP' * AP @@ -400,7 +402,7 @@ function remove_svd_gauge_dependence!( mul!(ΔU₁, U₁, gaugepart, -1, 1) if size(ΔU, 2) > r if r < length(Sdiag) # rank-deficient case, no stable information can be extracted from extra columns of U - ΔU[:, (r + 1):end] .= 0 + zero!(ΔU[:, (r + 1):end]) else # the component of ΔU₂ along U₁ contains gauge-invariant information p = size(ΔU, 2) ΔU₂ = view(ΔU, :, (r + 1):p) @@ -410,7 +412,7 @@ function remove_svd_gauge_dependence!( end if size(ΔVᴴ, 1) > r if r < length(Sdiag) # rank-deficient case, no stable information can be extracted from extra rows of Vᴴ - ΔVᴴ[(r + 1):end, :] .= 0 + zero!(ΔVᴴ[(r + 1):end, :]) else # the component of ΔVᴴ₂ along Vᴴ₁ contains gauge-invariant information p = size(ΔVᴴ, 1) ΔVᴴ₂ = view(ΔVᴴ, (r + 1):p, :) From 7fdc9b46f5bb4eb6526e58ea6443a600b80a58d5 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Wed, 22 Apr 2026 22:50:31 +0200 Subject: [PATCH 05/13] some changes from review --- src/pullbacks/svd.jl | 75 ++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 44 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 615509e15..0c7f74252 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -42,11 +42,11 @@ function check_and_prepare_svd_cotangents( end end UᴴΔU₁ = U₁' * ΔU₁ - ΔU₁ = mul!(ΔU₁, U₁, UᴴΔU₁, -1, 1) + ΔU₊ = mul!(ΔU₁, U₁, UᴴΔU₁, -1, 1) aUᴴΔU₁ = project_antihermitian!(UᴴΔU₁) Δgauge = max(Δgauge, ΔgaugeU) else - ΔU₁ = nothing + ΔU₊ = nothing aUᴴΔU₁ = zero!(similar(U₁, (r, r))) end if !iszerotangent(ΔVᴴ) @@ -74,11 +74,11 @@ function check_and_prepare_svd_cotangents( end end VᴴΔV₁ = V₁ᴴ * ΔV₁ᴴ' - ΔV₁ᴴ = mul!(ΔV₁ᴴ, VᴴΔV₁', V₁ᴴ, -1, 1) + ΔV₊ᴴ = mul!(ΔV₁ᴴ, VᴴΔV₁', V₁ᴴ, -1, 1) aVᴴΔV₁ = project_antihermitian!(VᴴΔV₁) Δgauge = max(Δgauge, ΔgaugeV) else - ΔV₁ᴴ = nothing + ΔV₊ᴴ = nothing aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end mask = abs.(S₁' .- S₁) .< degeneracy_atol @@ -101,7 +101,14 @@ function check_and_prepare_svd_cotangents( Δgauge ≤ gauge_atol || @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return ΔU₁, ΔS₁, ΔV₁ᴴ, aUᴴΔU₁, aVᴴΔV₁ + + UdΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ + (aUᴴΔU₁ .- aVᴴΔV₁) .* inv_safe.(S₁' .+ S₁, degeneracy_atol) + if !iszerotangent(ΔS₁) + diagview(UdΔAV) .+= real.(ΔS₁) + end + + return UdΔAV, ΔU₊, ΔV₊ᴴ end """ @@ -145,25 +152,19 @@ function svd_pullback!( S₁ = view(S, 1:r) ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - ΔU₁, ΔS₁, ΔV₁ᴴ, aUᴴΔU₁, aVᴴΔV₁ = check_and_prepare_svd_cotangents( + UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol ) - - UdΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ - (aUᴴΔU₁ .- aVᴴΔV₁) .* inv_safe.(S₁' .+ S₁, degeneracy_atol) - if !iszerotangent(ΔS₁) - diagview(UdΔAV) .+= real.(ΔS₁) - end ΔA = mul!(ΔA, U₁, UdΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA # Add the remaining contributions - if m > r && !iszerotangent(ΔU₁) # ΔU₁ is already orthogonal to U₁ - ΔU₁ ./= S₁' - ΔA = mul!(ΔA, ΔU₁, V₁ᴴ, 1, 1) + if m > r && !iszerotangent(ΔU₊) # ΔU₁ is already orthogonal to U₁ + ΔU₊ ./= S₁' + ΔA = mul!(ΔA, ΔU₊, V₁ᴴ, 1, 1) end - if n > r && !iszerotangent(ΔV₁ᴴ) # ΔV₁ᴴ is already orthogonal to V₁ᴴ - ΔV₁ᴴ .= S₁ .\ ΔV₁ᴴ - ΔA = mul!(ΔA, U₁, ΔV₁ᴴ, 1, 1) + if n > r && !iszerotangent(ΔV₊ᴴ) # ΔV₁ᴴ is already orthogonal to V₁ᴴ + ΔV₊ᴴ .= S₁ .\ ΔV₊ᴴ + ΔA = mul!(ΔA, U₁, ΔV₊ᴴ, 1, 1) end return ΔA end @@ -201,7 +202,7 @@ A warning will be printed if the cotangents are not gauge-invariant, i.e. if the anti-hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'`, restricted to rows `i` and columns `j` for which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol`. """ -function svd_trunc_pullback!( +function svd_trunc_pullback2!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), @@ -220,24 +221,17 @@ function svd_trunc_pullback!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - ΔU, ΔS, ΔVᴴ, aUᴴΔU, aVᴴΔV = check_and_prepare_svd_cotangents( - U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol + UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol ) - - # This part is the same as in `svd_pullback!` - UdΔAV = (aUᴴΔU .+ aVᴴΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ - (aUᴴΔU .- aVᴴΔV) .* inv_safe.(S' .+ S, degeneracy_atol) - if !iszerotangent(ΔS) - diagview(UdΔAV) .+= real.(ΔS) - end ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA # The contributions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ - if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) + if !(iszerotangent(ΔU₊) && iszerotangent(ΔV₊ᴴ)) Aperp = mul!(copy(A), U, Smat * Vᴴ, -1, 1) - x₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) - y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) + x₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) + y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) X = copy(x₀) Yᴴ = copy(y₀ᴴ) xₖ, xₖ₊₁ = x₀, zero(x₀) @@ -261,7 +255,7 @@ function svd_trunc_pullback!( end return ΔA end -function svd_trunc_pullback2!( +function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), @@ -280,23 +274,16 @@ function svd_trunc_pullback2!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - ΔU, ΔS, ΔVᴴ, aUᴴΔU, aVᴴΔV = check_and_prepare_svd_cotangents( - U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol + UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol ) - - # This part is the same as in `svd_pullback!` - UdΔAV = (aUᴴΔU .+ aVᴴΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ - (aUᴴΔU .- aVᴴΔV) .* inv_safe.(S' .+ S, degeneracy_atol) - if !iszerotangent(ΔS) - diagview(UdΔAV) .+= real.(ΔS) - end ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA # The contribtutions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ - if !(iszerotangent(ΔU) && iszerotangent(ΔVᴴ)) - X₀ = iszerotangent(ΔU) ? zero(U) : rdiv!(ΔU, Diagonal(S)) - Y₀ᴴ = iszerotangent(ΔVᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔVᴴ) + if !(iszerotangent(ΔU₊) && iszerotangent(ΔV₊ᴴ)) + X₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) + Y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) AP ./= S[1] S = S ./ S[1] From b5a73b17d6900e5798ca5d18e9afdadd9594b3d3 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 23 Apr 2026 00:28:44 +0200 Subject: [PATCH 06/13] fixes and improved numerical stability --- src/pullbacks/svd.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 0c7f74252..c69d2cb77 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -222,7 +222,7 @@ function svd_trunc_pullback2!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( - U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol ) ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA @@ -275,7 +275,7 @@ function svd_trunc_pullback!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( - U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol + U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol ) ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA @@ -285,20 +285,20 @@ function svd_trunc_pullback!( X₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) Y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) - AP ./= S[1] - S = S ./ S[1] - X₁ = rdiv!(AP * Y₀ᴴ', Diagonal(S)) + AP ./= S[end] + Sinv = S[end] ./ S + X₁ = rmul!(AP * Y₀ᴴ', Diagonal(Sinv)) X₁ .+= X₀ - Y₁ᴴ = ldiv!(Diagonal(S), X₀' * AP) + Y₁ᴴ = lmul!(Diagonal(Sinv), X₀' * AP) Y₁ᴴ .+= Y₀ᴴ Xₖ, Xₖ₊₁ = X₁, X₀ Yₖᴴ, Yₖ₊₁ᴴ = Y₁ᴴ, Y₀ᴴ APAᴴₖ, AᴴPAₖ = AP * AP', AP' * AP APAᴴₖ₊₁, AᴴPAₖ₊₁ = zero(APAᴴₖ), zero(AᴴPAₖ) - Sₖ, Sₖ₊₁ = S .^ 2, S + Sinvₖ, Sinvₖ₊₁ = Sinv .^ 2, Sinv for k in 1:maxiter - Xₖ₊₁ = rdiv!(mul!(Xₖ₊₁, APAᴴₖ, Xₖ), Diagonal(Sₖ)) - Yₖ₊₁ᴴ = ldiv!(Diagonal(Sₖ), mul!(Yₖ₊₁ᴴ, Yₖᴴ, AᴴPAₖ)) + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APAᴴₖ, Xₖ), Diagonal(Sinvₖ)) + Yₖ₊₁ᴴ = lmul!(Diagonal(Sinvₖ), mul!(Yₖ₊₁ᴴ, Yₖᴴ, AᴴPAₖ)) if norm(Xₖ₊₁, Inf) < degeneracy_atol && norm(Yₖ₊₁ᴴ, Inf) < degeneracy_atol break end @@ -308,14 +308,14 @@ function svd_trunc_pullback!( @warn "Sylvester iteration did not converge after $k iterations, final norms: (X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" break end - Sₖ₊₁ .= Sₖ .^ 2 + Sinvₖ₊₁ .= Sinvₖ .^ 2 APAᴴₖ₊₁ = mul!(APAᴴₖ₊₁, APAᴴₖ, APAᴴₖ) AᴴPAₖ₊₁ = mul!(AᴴPAₖ₊₁, AᴴPAₖ, AᴴPAₖ) Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ Yₖᴴ, Yₖ₊₁ᴴ = Yₖ₊₁ᴴ, Yₖᴴ APAᴴₖ, APAᴴₖ₊₁ = APAᴴₖ₊₁, APAᴴₖ AᴴPAₖ, AᴴPAₖ₊₁ = AᴴPAₖ₊₁, AᴴPAₖ - Sₖ, Sₖ₊₁ = Sₖ₊₁, Sₖ + Sinvₖ, Sinvₖ₊₁ = Sinvₖ₊₁, Sinvₖ end ΔA = mul!(ΔA, Xₖ, Vᴴ, 1, 1) ΔA = mul!(ΔA, U, Yₖᴴ, 1, 1) From b4df95c2c0fbcaf3b1dd7d749bb47579494f3f15 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 23 Apr 2026 08:33:33 +0200 Subject: [PATCH 07/13] one more code suggestion Co-authored-by: Lukas Devos --- src/pullbacks/svd.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index c69d2cb77..29d551231 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -81,8 +81,10 @@ function check_and_prepare_svd_cotangents( ΔV₊ᴴ = nothing aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end - mask = abs.(S₁' .- S₁) .< degeneracy_atol - Δgauge = max(Δgauge, norm(view(aUᴴΔU₁, mask) + view(aVᴴΔV₁, mask), Inf)) + bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s1, s2, u, v + return abs(s1 - s2) < degeneracy_atol ? zero(u) + zero(v) : u + v + end + Δgauge = max(Δgauge, norm(bc, Inf)) if !iszerotangent(ΔSmat) ΔS = diagview(ΔSmat) From 62a5971be51c023ccc3e9c0c8a27b34791c5b5d7 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Apr 2026 08:08:54 -0400 Subject: [PATCH 08/13] mark gauge dependence removal as public --- src/MatrixAlgebraKit.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 5e59b5763..8d66c4411 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -72,6 +72,15 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter :svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback! ) ) + eval( + Expr( + :public, :remove_svd_gauge_dependence!, + :remove_eig_gauge_dependence!, :remove_eigh_gauge_dependence!, + :remove_qr_gauge_dependence!, :remove_qr_null_gauge_dependence!, + :remove_lq_gauge_dependence!, :remove_lq_null_gauge_dependence!, + :remove_left_null_gauge_dependence!, :remove_right_null_gauge_dependence!, + ) + ) eval(Expr(:public, :is_left_isometric, :is_right_isometric)) end From e7aa6e4db38846800a4f8ce1deb9160f095be424 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Apr 2026 10:46:52 -0400 Subject: [PATCH 09/13] Update src/pullbacks/svd.jl Co-authored-by: Jutho --- src/pullbacks/svd.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 29d551231..9b14f5922 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -127,9 +127,10 @@ Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with magnitude less than `rank_atol` are missing from `S`. For the cotangents, an arbitrary number of singular vectors or singular values can be missing, i.e. for a matrix `A` with -size `(m, n)`, `ΔU` and `ΔVᴴ` can have sizes `(m, p)` and `(p, n)` respectively, whereas -`diagview(ΔS)` can have length `p`. In those cases, an additional list `ind` of length `p` -is required to specify which singular vectors and values are present in `ΔU`, `ΔS` and `ΔVᴴ`. +size `(m, n)`, `ΔU`, `ΔS` and `ΔVᴴ` can have sizes `(m, p)`, `(p, p)` and `(p, n)` respectively +and the argument `ind` is a list of length `p` indicating that these are cotangents corresponding to `U[:, ind]`, `S[ind, ind]` and `Vᴴ[ind, :]`, +whereas cotangents with respect to the other rows and columns are zero. +If `ind` is not present, `ΔU`, `ΔS` and `ΔVᴴ` are assumed to have the same size as `U`, `S` and `Vᴴ` respectively. A warning will be printed if the cotangents are not gauge-invariant, i.e. if the anti-hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'`, restricted to rows `i` and columns `j` for From 36d7581561cbd07069a818ca197315bde21afca1 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Apr 2026 11:18:56 -0400 Subject: [PATCH 10/13] improve error messages --- src/pullbacks/svd.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 9b14f5922..f7349c821 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -19,8 +19,8 @@ function check_and_prepare_svd_cotangents( if !iszerotangent(ΔU) ΔgaugeU = zero(eltype(S)) - m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) - length(indU) == size(ΔU, 2) || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) + m == size(ΔU, 1) || throw(DimensionMismatch(lazy"first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) + length(indU) == size(ΔU, 2) || throw(DimensionMismatch(lazy"length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) if indU == 1:r ΔU₁ = copy(ΔU) else @@ -51,8 +51,8 @@ function check_and_prepare_svd_cotangents( end if !iszerotangent(ΔVᴴ) ΔgaugeV = zero(eltype(S)) - n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) - length(indV) == size(ΔVᴴ, 1) || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) + n == size(ΔVᴴ, 2) || throw(DimensionMismatch(lazy"second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) + length(indV) == size(ΔVᴴ, 1) || throw(DimensionMismatch(lazy"length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) if indV == 1:r ΔV₁ᴴ = copy(ΔVᴴ) else @@ -88,7 +88,7 @@ function check_and_prepare_svd_cotangents( if !iszerotangent(ΔSmat) ΔS = diagview(ΔSmat) - length(indS) == length(ΔS) || throw(DimensionMismatch("length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))")) + length(indS) == length(ΔS) || throw(DimensionMismatch(lazy"length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))")) ΔS₁ = zero(S₁) for (j, i) in enumerate(indS) if i <= r @@ -146,7 +146,7 @@ function svd_pullback!( U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) minmn = min(m, n) - (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + (m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)")) S = diagview(Smat) r = svd_rank(S; rank_atol) @@ -265,15 +265,14 @@ function svd_trunc_pullback!( gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), maxiter::Int = 1000, ) - # Extract the SVD components U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch()) - p = size(U, 2) - p == size(Vᴴ, 1) || throw(DimensionMismatch()) + (m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)")) S = diagview(Smat) - p == length(S) || throw(DimensionMismatch()) + p = length(S) + p == size(U, 2) || throw(DimensionMismatch(lazy"U has $p columns but S has $(length(S)) singular values")) + p == size(Vᴴ, 1) || throw(DimensionMismatch(lazy"Vᴴ has $p rows but S has $(length(S)) singular values")) # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ From 995dccbce5cfd775619fe0a79ff9605b3ae860ec Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Apr 2026 11:19:13 -0400 Subject: [PATCH 11/13] remove unused function --- src/pullbacks/svd.jl | 53 -------------------------------------------- 1 file changed, 53 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index f7349c821..27d339f07 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -205,59 +205,6 @@ A warning will be printed if the cotangents are not gauge-invariant, i.e. if the anti-hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'`, restricted to rows `i` and columns `j` for which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol`. """ -function svd_trunc_pullback2!( - ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; - rank_atol::Real = 0, - degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), - maxiter::Int = 1000, - ) - - # Extract the SVD components - U, Smat, Vᴴ = USVᴴ - m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch()) - p = size(U, 2) - p == size(Vᴴ, 1) || throw(DimensionMismatch()) - S = diagview(Smat) - p == length(S) || throw(DimensionMismatch()) - - # Extract and check the cotangents - ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( - U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol - ) - ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA - - # The contributions from the orthogonal complement need to be treated differently - # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ - if !(iszerotangent(ΔU₊) && iszerotangent(ΔV₊ᴴ)) - Aperp = mul!(copy(A), U, Smat * Vᴴ, -1, 1) - x₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) - y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) - X = copy(x₀) - Yᴴ = copy(y₀ᴴ) - xₖ, xₖ₊₁ = x₀, zero(x₀) - yₖᴴ, yₖ₊₁ᴴ = y₀ᴴ, zero(y₀ᴴ) - for k in 1:maxiter - xₖ₊₁ = rdiv!(mul!(xₖ₊₁, Aperp, yₖᴴ'), Diagonal(S)) - yₖ₊₁ᴴ = ldiv!(Diagonal(S), mul!(yₖ₊₁ᴴ, xₖ', Aperp)) - X .+= xₖ₊₁ - Yᴴ .+= yₖ₊₁ᴴ - if norm(xₖ₊₁, Inf) < degeneracy_atol && norm(yₖ₊₁ᴴ, Inf) < degeneracy_atol - break - end - xₖ, xₖ₊₁ = xₖ₊₁, xₖ - yₖᴴ, yₖ₊₁ᴴ = yₖ₊₁ᴴ, yₖᴴ - if k == maxiter - @warn "Sylvester iteration did not converge after $k iterations, final norms: (x: $(norm(xₖ₊₁, Inf)), y: $(norm(yₖ₊₁ᴴ, Inf)))" - end - end - ΔA = mul!(ΔA, X, Vᴴ, 1, 1) - ΔA = mul!(ΔA, U, Yᴴ, 1, 1) - end - return ΔA -end function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, From ab43dd5c580353bec49f0db0f55b7b7ec57ef6b0 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Apr 2026 11:31:30 -0400 Subject: [PATCH 12/13] formatting --- src/pullbacks/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 27d339f07..94343e39a 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -81,8 +81,8 @@ function check_and_prepare_svd_cotangents( ΔV₊ᴴ = nothing aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end - bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s1, s2, u, v - return abs(s1 - s2) < degeneracy_atol ? zero(u) + zero(v) : u + v + bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v + return abs(s₁ - s₂) < degeneracy_atol ? zero(u) + zero(v) : u + v end Δgauge = max(Δgauge, norm(bc, Inf)) From efa0171e62bc6b4d04562042c53b197383aa7530 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Apr 2026 11:48:50 -0400 Subject: [PATCH 13/13] more unicode --- src/pullbacks/svd.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 94343e39a..4b3524160 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -235,19 +235,19 @@ function svd_trunc_pullback!( Y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) AP ./= S[end] - Sinv = S[end] ./ S - X₁ = rmul!(AP * Y₀ᴴ', Diagonal(Sinv)) + S⁻¹ = S[end] ./ S + X₁ = rmul!(AP * Y₀ᴴ', Diagonal(S⁻¹)) X₁ .+= X₀ - Y₁ᴴ = lmul!(Diagonal(Sinv), X₀' * AP) + Y₁ᴴ = lmul!(Diagonal(S⁻¹), X₀' * AP) Y₁ᴴ .+= Y₀ᴴ Xₖ, Xₖ₊₁ = X₁, X₀ Yₖᴴ, Yₖ₊₁ᴴ = Y₁ᴴ, Y₀ᴴ APAᴴₖ, AᴴPAₖ = AP * AP', AP' * AP APAᴴₖ₊₁, AᴴPAₖ₊₁ = zero(APAᴴₖ), zero(AᴴPAₖ) - Sinvₖ, Sinvₖ₊₁ = Sinv .^ 2, Sinv + S⁻¹ₖ, S⁻¹ₖ₊₁ = S⁻¹ .^ 2, S⁻¹ for k in 1:maxiter - Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APAᴴₖ, Xₖ), Diagonal(Sinvₖ)) - Yₖ₊₁ᴴ = lmul!(Diagonal(Sinvₖ), mul!(Yₖ₊₁ᴴ, Yₖᴴ, AᴴPAₖ)) + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APAᴴₖ, Xₖ), Diagonal(S⁻¹ₖ)) + Yₖ₊₁ᴴ = lmul!(Diagonal(S⁻¹ₖ), mul!(Yₖ₊₁ᴴ, Yₖᴴ, AᴴPAₖ)) if norm(Xₖ₊₁, Inf) < degeneracy_atol && norm(Yₖ₊₁ᴴ, Inf) < degeneracy_atol break end @@ -257,14 +257,14 @@ function svd_trunc_pullback!( @warn "Sylvester iteration did not converge after $k iterations, final norms: (X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" break end - Sinvₖ₊₁ .= Sinvₖ .^ 2 + S⁻¹ₖ₊₁ .= S⁻¹ₖ .^ 2 APAᴴₖ₊₁ = mul!(APAᴴₖ₊₁, APAᴴₖ, APAᴴₖ) AᴴPAₖ₊₁ = mul!(AᴴPAₖ₊₁, AᴴPAₖ, AᴴPAₖ) Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ Yₖᴴ, Yₖ₊₁ᴴ = Yₖ₊₁ᴴ, Yₖᴴ APAᴴₖ, APAᴴₖ₊₁ = APAᴴₖ₊₁, APAᴴₖ AᴴPAₖ, AᴴPAₖ₊₁ = AᴴPAₖ₊₁, AᴴPAₖ - Sinvₖ, Sinvₖ₊₁ = Sinvₖ₊₁, Sinvₖ + S⁻¹ₖ, S⁻¹ₖ₊₁ = S⁻¹ₖ₊₁, S⁻¹ₖ end ΔA = mul!(ΔA, Xₖ, Vᴴ, 1, 1) ΔA = mul!(ΔA, U, Yₖᴴ, 1, 1)