Skip to content
14 changes: 1 addition & 13 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,7 @@ for f in (:svd_compact!, :svd_full!)
USVᴴval = something(cache_USVᴴ, USVᴴ.val)
if !isa(A, Const)
minmn = min(size(A.val)...)
if $(f == svd_compact!) # compact
svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ)
else # full
# TODO: revisit this once `svd_pullback` supports `svd_full` output and adjoints
U, S, Vᴴ = USVᴴval
vU = view(U, :, 1:minmn)
vS = Diagonal(view(diagview(S), 1:minmn))
vVᴴ = view(Vᴴ, 1:minmn, :)
vdU = view(dUSVᴴ[1], :, 1:minmn)
vdS = Diagonal(view(diagview(dUSVᴴ[2]), 1:minmn))
vdVᴴ = view(dUSVᴴ[3], 1:minmn, :)
svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ)
end
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
return (nothing, nothing, nothing)
Expand Down
30 changes: 4 additions & 26 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,28 +418,17 @@ for (f!, f) in (
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
Ac = copy(A)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
Ac = copy(A)
USVᴴc = copy.(USVᴴ)
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
minmn = min(size(A)...)
vU = view(U, :, 1:minmn)
vS = Diagonal(diagview(S)[1:minmn])
vVᴴ = view(Vᴴ, 1:minmn, :)
vdU = view(dU, :, 1:minmn)
vdS = Diagonal(diagview(dS)[1:minmn])
vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
Expand All @@ -448,7 +437,7 @@ for (f!, f) in (
zero!(dVᴴ)
return NoRData(), NoRData(), NoRData(), NoRData()
end
return CoDual(output, dUSVᴴ), svd_adjoint
return USVᴴ_dUSVᴴ, svd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
Expand All @@ -465,18 +454,7 @@ for (f!, f) in (
U, dU = arrayify(U, dU_)
S, dS = arrayify(S, dS_)
Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_)
if $(f == svd_compact)
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
minmn = min(size(A)...)
vU = view(U, :, 1:minmn)
vS = Diagonal(view(diagview(S), 1:minmn))
vVᴴ = view(Vᴴ, 1:minmn, :)
vdU = view(dU, :, 1:minmn)
vdS = Diagonal(view(diagview(dS), 1:minmn))
vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
zero!(dU)
zero!(dS)
zero!(dVᴴ)
Expand Down
9 changes: 9 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
:svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback!
)
)
eval(
Expr(
:public, :remove_svd_gauge_dependence!,
:remove_eig_gauge_dependence!, :remove_eigh_gauge_dependence!,
:remove_qr_gauge_dependence!, :remove_qr_null_gauge_dependence!,
:remove_lq_gauge_dependence!, :remove_lq_null_gauge_dependence!,
:remove_left_null_gauge_dependence!, :remove_right_null_gauge_dependence!,
)
)
eval(Expr(:public, :is_left_isometric, :is_right_isometric))
end

Expand Down
22 changes: 22 additions & 0 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,25 @@ function eig_vals_pullback!(
ΔDV = (diagonal(ΔD), nothing)
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end

"""
remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...)

Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The
eigenvectors are only determined up to a scalar factor (or an abitrary linear transformation
across eigenvectors associated with degenerate eigenvalues), so the corresponding components of
`ΔV` are projected out.
"""
function remove_eig_gauge_dependence!(
ΔV, D, V, ind = axes(ΔV, 2);
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
)
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
indV = axes(V, 2)[ind]
Vp = view(V, :, indV)
Ddiag = view(diagview(D), indV)
gaugepart = Vp' * ΔV
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1)
return ΔV
end
22 changes: 22 additions & 0 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,25 @@ function eigh_vals_pullback!(
ΔDV = (diagonal(ΔD), nothing)
return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end

"""
remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...)

Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix
`V`. The eigenvectors are only determined up to a complex phase (or a unitary transformation
across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian
components of `V' * ΔV` are projected out.
"""
function remove_eigh_gauge_dependence!(
ΔV, D, V, ind = axes(ΔV, 2);
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
)
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
indV = axes(V, 2)[ind]
Vp = view(V, :, indV)
Ddiag = view(diagview(D), indV)
gaugepart = project_antihermitian!(Vp' * ΔV)
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
mul!(ΔV, Vp, gaugepart, -1, 1)
return ΔV
end
123 changes: 85 additions & 38 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
lq_rank(L; kwargs...) = qr_rank(L; kwargs...)

function check_lq_cotangents(
function check_and_prepare_lq_cotangents(
L, Q, ΔL, ΔQ, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(L, 1), size(Q, 2))
m, n = size(L, 1), size(Q, 2)
minmn = min(m, n)
Δgauge = abs(zero(eltype(Q)))
Q₁ = view(Q, 1:p, :)
ΔQ₁ = zero!(similar(Q₁))
if !iszerotangent(ΔQ)
ΔQ₂ = view(ΔQ, (p + 1):minmn, :)
ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :]
Δgauge_Q = norm(ΔQ₂, Inf)
Q₁ = view(Q, 1:p, :)
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q"))
ΔQ₁ .= view(ΔQ, 1:p, 1:n)
if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :)
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
Δgauge_Q = norm(ΔQ₃, Inf)
mul!(ΔQ₁, ΔQ₃Q₁ᴴ', Q₃, -1, 1)
else
ΔQ₂ = view(ΔQ, (p + 1):size(ΔQ, 1), :)
Δgauge_Q = norm(ΔQ₂, Inf)
end
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf))
size(ΔL) == size(L) || throw(DimensionMismatch("ΔL must have the same size as L"))
ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p))
ΔL₂₁ = view(ΔL, (p + 1):size(ΔL, 1), 1:p)
ΔL₂₂ = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
Δgauge_L = norm(view(ΔL₂₂, lowertriangularind(ΔL₂₂)), Inf)
Δgauge_L = max(Δgauge_L, norm(view(ΔL₂₂, diagind(ΔL₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_L)
else
ΔL₁₁ = nothing
ΔL₂₁ = nothing
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return nothing
return ΔL₁₁, ΔL₂₁, ΔQ₁
end

"""
Expand Down Expand Up @@ -53,53 +68,37 @@ function lq_pullback!(
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
minmn = min(m, n)
p = lq_rank(L; rank_atol)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of L*Q ($m, $n)"))

ΔL, ΔQ = ΔLQ

Q₁ = view(Q, 1:p, :)
L₁₁ = LowerTriangular(view(L, 1:p, 1:p))
L₂₁ = view(L, (p + 1):m, 1:p)
Q₁ = view(Q, 1:p, :)

ΔA₁ = view(ΔA, 1:p, :)
ΔA₂ = view(ΔA, (p + 1):m, :)

check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
ΔL, ΔQ = ΔLQ
ΔL₁₁, ΔL₂₁, ΔQ₁ = check_and_prepare_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
ΔQ₁ = view(ΔQ, 1:p, :)
copy!(ΔQ̃, ΔQ₁)
if minmn < size(Q, 1)
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :)
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1)
end
end
if !iszerotangent(ΔL) && m > p
L₂₁ = view(L, (p + 1):m, 1:p)
ΔL₂₁ = view(ΔL, (p + 1):m, 1:p)
ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1)
ΔQ₁ = mul!(ΔQ₁, L₂₁' * ΔL₂₁, Q₁, -1, 1)
# Adding ΔA₂ contribution
ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1)
end

# construct M
M = zero!(similar(L, (p, p)))
if !iszerotangent(ΔL)
ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p))
M = mul!(M, L₁₁', ΔL₁₁, 1, 1)
end
M = mul!(M, ΔQ̃, Q₁', -1, 1)
M = mul!(M, ΔQ₁, Q₁', -1, 1)
view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
if eltype(M) <: Complex
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(L₁₁', M)
ldiv!(L₁₁', ΔQ̃)
ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1)
ΔA₁ .+= ΔQ̃
ΔA₁ .+= ldiv!(L₁₁', mul!(ΔQ₁, M, Q₁, +1, 1))
return ΔA
end

Expand Down Expand Up @@ -134,3 +133,51 @@ function lq_null_pullback!(
end
return ΔA
end


"""
remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...)

Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and
`Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely
determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity.
Additionally, columns of `ΔL` beyond the rank are zeroed out.
"""
function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L))
r = MatrixAlgebraKit.lq_rank(L; rank_atol)
minmn = min(size(A)...)
Q₁ = view(Q, 1:r, :)
ΔQ₂ = view(ΔQ, (r + 1):minmn, :)
zero!(ΔQ₂)
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full
if r == minmn
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁)
else # rank-deficient case, no gauge-invariant information
zero!(ΔQ₃)
end
ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn)
zero!(diagview(ΔL₂₂))
zero!(view(ΔL₂₂, lowertriangularind(ΔL₂₂)))
return ΔL, ΔQ
end

"""
remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)

Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null
space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of
the compact LQ factor `Q₁`.
"""
function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
return mul!(ΔNᴴ, ΔNᴴ * Nᴴ', Nᴴ, -1, 1)
end

"""
remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)

Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The
null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the
row span of the compact LQ factor `Q₁` of `A`.
"""
remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
Loading
Loading