From 54a5188ad4586e9214c14e69bd7831aa624a2566 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 15 Oct 2025 18:10:59 -0400 Subject: [PATCH] GPU improvements for SVD Pull in SVD-specific changes from ksh/tk --- src/implementations/svd.jl | 45 ++++++++++++++++++++++++++++---------- test/amd/svd.jl | 25 +++++++++++++++++++-- test/cuda/svd.jl | 24 +++++++++++++++++--- test/svd.jl | 14 ++++++++++++ 4 files changed, 91 insertions(+), 17 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index e8ec7e21..9ebc7319 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -347,25 +347,46 @@ function _gpu_gesvdj!( ) throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) end +function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) + m, n = size(A) + m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ) + # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration) + # if this condition is not met, do the SVD via adjoint + minmn = min(m, n) + Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') + Uᴴ = similar(U') + V = similar(Vᴴ') + if size(U) == (m, m) + _gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) + else + _gpu_gesvd!(Aᴴ, S, V, Uᴴ) + end + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return U, S, Vᴴ +end + # GPU SVD implementation -function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_full!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ fill!(S, zero(eltype(S))) m, n = size(A) minmn = min(m, n) + if minmn == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_Bisection - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - # elseif alg isa LAPACK_Jacobi - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -390,13 +411,13 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran return USVᴴtrunc..., ϵ end -function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, S.diag, U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi @@ -416,8 +437,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, S, U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi diff --git a/test/amd/svd.jl b/test/amd/svd.jl index 893b41b2..4f58f925 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl")) k = min(m, n) algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs - n > m && alg isa ROCSOLVER_QRIteration && continue # not supported minmn = min(m, n) A = ROCArray(randn(rng, T, m, n)) @@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl")) Sd = svd_vals(A, alg) @test ROCArray(diagview(S)) ≈ Sd # ROCArray is necessary because norm of ROCArray view with non-unit step is broken + if alg isa ROCSOLVER_QRIteration + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) + end end end end @@ -51,7 +53,6 @@ end @testset "size ($m, $n)" for n in (37, m, 63) algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs - n > m && alg isa ROCSOLVER_QRIteration && continue # not supported A = ROCArray(randn(rng, T, m, n)) U, S, Vᴴ = svd_full(A; alg) @test U isa ROCMatrix{T} && size(U) == (m, m) @@ -81,6 +82,26 @@ end @test Sc === Sc2 @test ROCArray(diagview(S)) ≈ Sc # ROCArray is necessary because norm of ROCArray view with non-unit step is broken + if alg isa ROCSOLVER_QRIteration + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad")) + end + end + end + @testset "size (0, 0)" begin + algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) + @testset "algorithm $alg" for alg in algs + A = ROCArray(randn(rng, T, 0, 0)) + U, S, Vᴴ = svd_full(A; alg) + @test U isa ROCMatrix{T} && size(U) == (0, 0) + @test S isa ROCMatrix{real(T)} && size(S) == (0, 0) + @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (0, 0) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) end end end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 36f923e3..bfe993f8 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl")) k = min(m, n) algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs - n > m && alg isa CUSOLVER_QRIteration && continue # not supported minmn = min(m, n) A = CuArray(randn(rng, T, m, n)) @@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl")) Sd = svd_vals(A, alg) @test CuArray(diagview(S)) ≈ Sd # CuArray is necessary because norm of CuArray view with non-unit step is broken + if alg isa CUSOLVER_QRIteration + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) + end end end end @@ -51,7 +53,6 @@ end @testset "size ($m, $n)" for n in (37, m, 63) algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs - n > m && alg isa CUSOLVER_QRIteration && continue # not supported A = CuArray(randn(rng, T, m, n)) U, S, Vᴴ = svd_full(A; alg) @test U isa CuMatrix{T} && size(U) == (m, m) @@ -82,8 +83,26 @@ end @test Sc === Sc2 @test CuArray(diagview(S)) ≈ Sc # CuArray is necessary because norm of CuArray view with non-unit step is broken + if alg isa CUSOLVER_QRIteration + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) + @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad")) + end end + end + @testset "size (0, 0)" begin + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs + A = CuArray(randn(rng, T, 0, 0)) + U, S, Vᴴ = svd_full(A; alg) + @test U isa CuMatrix{T} && size(U) == (0, 0) + @test S isa CuMatrix{real(T)} && size(S) == (0, 0) + @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) end end end @@ -96,7 +115,6 @@ end p = min(m, n) - k - 1 algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100)) @testset "algorithm $alg" for alg in algs - n > m && alg isa CUSOLVER_QRIteration && continue # not supported hA = randn(rng, T, m, n) S₀ = svd_vals(hA) A = CuArray(hA) diff --git a/test/svd.jl b/test/svd.jl index d9016b4e..acb27946 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -92,6 +92,20 @@ end @test diagview(S) ≈ Sc end end + @testset "size (0, 0)" begin + @testset "algorithm $alg" for alg in + (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) + A = randn(rng, T, 0, 0) + U, S, Vᴴ = svd_full(A; alg) + @test U isa Matrix{T} && size(U) == (0, 0) + @test S isa Matrix{real(T)} && size(S) == (0, 0) + @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + end + end end @testset "svd_trunc! for T = $T" for T in BLASFloats