Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ function check_and_prepare_qr_cotangents(
ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p))
ΔR₁₂ = view(ΔR, 1:p, (p + 1):n)
ΔR₂₂ = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge_R = norm(view(ΔR₂₂, uppertriangularind(ΔR₂₂)), Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
if p < minmn # otherwise ΔR₂₂ is empty
# uppertriangularind generates linear indices
# compute the appropriate offset in ΔR so we aren't
# operating on a view-of-view, which doesn't work
# for GPU arrays
I = uppertriangularind(ΔR₂₂)
upper_inds = view(LinearIndices(ΔR), (p + 1):minmn, (p + 1):n)[I]
ΔR₂₂upper = view(ΔR, upper_inds)
Δgauge_R = norm(ΔR₂₂upper, Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
end
else
ΔR₁₁ = nothing
ΔR₁₂ = nothing
Expand Down Expand Up @@ -75,7 +84,7 @@ function qr_pullback!(


Q₁ = view(Q, :, 1:p)
R₁₁ = UpperTriangular(view(R, 1:p, 1:p))
R₁₁ = UpperTriangular(R[1:p, 1:p])

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a subtle and impactful change. The UpperTriangular wrapper is really only necessary to enable the rdiv! call below. If GPUs cannot deal with UpperTriangular of a view of a GPUArray, then maybe we need to call the corresponding BLAS/LAPACK methods directly, or have some intermediate wrapper like rdiv_uppertriangular!.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If GPUs cannot deal with UpperTriangular of a view of a GPUArray

Indeed they can't :(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I even wonder how rdiv!(::Matrix, ::UpperTriangular) is evaluated on the GPU, since you need cuSOLVERDx to access TRSM.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW cuSOLVERDx is only for device side code, so it can only be called by running CUDA kernels, not from host side code as we are doing here...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, this is internally in dividing using \ by a general matrix, which is then evaluated by computing its QR decomposition and then directly calling cuBLAS.trsm! on the triangular factor. Here we already have the triangular factor, so I indeed want to call trsm!, but using generic code, which is why I was using ldiv!/rdiv!. And I don't see where rdiv!(first_arg, second_arg::UpperTrangiular) is then actually lowered to cuBLAS.trsm! in the CuArray case (and why that only works for a pure CuMatrix and not for a view over it.)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To unblock this: shall we indeed create a helper rdiv!_uppertriangular! that for now just avoids the copy on the CPU and simply copies on the GPU with a # TODO: dispatch to trsm! directly

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But Then I don't understand why it doesn't work. The argument B::StridedCuMatrix in https://github.com/JuliaGPU/CUDA.jl/blob/fbb90981cbde21d979087ad518a510f5b38f95b3/lib/cublas/src/linalg.jl#L443 should accept a view of a CuMatrix, no? My apologies for being annoying, my lack of access to a GPU to test these things myself makes me ask these questions.

R₁₂ = view(R, 1:p, (p + 1):n)

ΔA₁ = view(ΔA, :, 1:p)
Expand All @@ -101,7 +110,8 @@ function qr_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ΔA₁ .+= rdiv!(mul!(ΔQ₁, Q₁, M, +1, 1), R₁₁')
mul!(ΔQ₁, Q₁, M, +1, 1)
ΔA₁ .+= rdiv!(ΔQ₁, R₁₁')
return ΔA
end

Expand Down Expand Up @@ -147,7 +157,8 @@ ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out.
"""
function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R))
r = MatrixAlgebraKit.qr_rank(R; rank_atol)
minmn = min(size(A)...)
m, n = size(A, 1), size(A, 2)
minmn = min(m, n)
Q₁ = view(Q, :, 1:r)
ΔQ₂ = view(ΔQ, :, (r + 1):minmn)
zero!(ΔQ₂)
Expand All @@ -160,7 +171,16 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr
end
ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
zero!(diagview(ΔR₂₂))
zero!(view(ΔR₂₂, uppertriangularind(ΔR₂₂)))
if r < minmn
# uppertriangularind generates linear indices
# compute the appropriate offset in ΔR so we aren't
# operating on a view-of-view, which doesn't work
# for GPU arrays
Comment thread
kshyatt marked this conversation as resolved.
I = uppertriangularind(ΔR₂₂)
upper_inds = view(LinearIndices(ΔR), (r + 1):minmn, (r + 1):n)[I]
ΔR₂₂upper = view(ΔR, upper_inds)
zero!(ΔR₂₂upper)
end
return ΔQ, ΔR
end

Expand Down
7 changes: 7 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
if T ∈ BLASFloats && CUDA.functional()
TestSuite.test_mooncake_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
#=if m == n
AT = Diagonal{T, CuVector{T}}
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end=# # currently broken
end
end
Loading