Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,25 +347,46 @@ function _gpu_gesvdj!(
)
throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
end
function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix)
m, n = size(A)
m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ)
# both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
# if this condition is not met, do the SVD via adjoint
minmn = min(m, n)
Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
Uᴴ = similar(U')
V = similar(Vᴴ')
if size(U) == (m, m)
_gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
else
_gpu_gesvd!(Aᴴ, S, V, Uᴴ)
end
length(U) > 0 && adjoint!(U, Uᴴ)
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
return U, S, Vᴴ
end

# GPU SVD implementation
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
check_input(svd_full!, A, USVᴴ, alg)
U, S, Vᴴ = USVᴴ
fill!(S, zero(eltype(S)))
m, n = size(A)
minmn = min(m, n)
if minmn == 0
one!(U)
zero!(S)
one!(Vᴴ)
return USVᴴ
end
if alg isa GPU_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
_gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
@warn "GPU_QRIteration does not accept any keyword arguments"
_gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ)
elseif alg isa GPU_SVDPolar
_gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
elseif alg isa GPU_Jacobi
_gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
# elseif alg isa LAPACK_Bisection
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
# elseif alg isa LAPACK_Jacobi
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
else
throw(ArgumentError("Unsupported SVD algorithm"))
end
Expand All @@ -390,13 +411,13 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
return USVᴴtrunc..., ϵ
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
check_input(svd_compact!, A, USVᴴ, alg)
U, S, Vᴴ = USVᴴ
if alg isa GPU_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
_gpu_gesvd!(A, S.diag, U, Vᴴ)
@warn "GPU_QRIteration does not accept any keyword arguments"
_gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ)
elseif alg isa GPU_SVDPolar
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
elseif alg isa GPU_Jacobi
Expand All @@ -416,8 +437,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
if alg isa GPU_QRIteration
isempty(alg.kwargs) ||
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
_gpu_gesvd!(A, S, U, Vᴴ)
@warn "GPU_QRIteration does not accept any keyword arguments"
_gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ)
elseif alg isa GPU_SVDPolar
_gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
elseif alg isa GPU_Jacobi
Expand Down
25 changes: 23 additions & 2 deletions test/amd/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl"))
k = min(m, n)
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
minmn = min(m, n)
A = ROCArray(randn(rng, T, m, n))

Expand All @@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl"))
Sd = svd_vals(A, alg)
@test ROCArray(diagview(S)) ≈ Sd
# ROCArray is necessary because norm of ROCArray view with non-unit step is broken
if alg isa ROCSOLVER_QRIteration
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad"))
end
end
end
end
Expand All @@ -51,7 +53,6 @@ end
@testset "size ($m, $n)" for n in (37, m, 63)
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
A = ROCArray(randn(rng, T, m, n))
U, S, Vᴴ = svd_full(A; alg)
@test U isa ROCMatrix{T} && size(U) == (m, m)
Expand Down Expand Up @@ -81,6 +82,26 @@ end
@test Sc === Sc2
@test ROCArray(diagview(S)) ≈ Sc
# ROCArray is necessary because norm of ROCArray view with non-unit step is broken
if alg isa ROCSOLVER_QRIteration
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad"))
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad"))
end
end
end
@testset "size (0, 0)" begin
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
A = ROCArray(randn(rng, T, 0, 0))
U, S, Vᴴ = svd_full(A; alg)
@test U isa ROCMatrix{T} && size(U) == (0, 0)
@test S isa ROCMatrix{real(T)} && size(S) == (0, 0)
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (0, 0)
@test U * S * Vᴴ ≈ A
@test isapproxone(U' * U)
@test isapproxone(U * U')
@test isapproxone(Vᴴ * Vᴴ')
@test isapproxone(Vᴴ' * Vᴴ)
@test all(isposdef, diagview(S))
end
end
end
Expand Down
24 changes: 21 additions & 3 deletions test/cuda/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl"))
k = min(m, n)
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
minmn = min(m, n)
A = CuArray(randn(rng, T, m, n))

Expand All @@ -41,6 +40,9 @@ include(joinpath("..", "utilities.jl"))
Sd = svd_vals(A, alg)
@test CuArray(diagview(S)) ≈ Sd
# CuArray is necessary because norm of CuArray view with non-unit step is broken
if alg isa CUSOLVER_QRIteration
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
end
end
end
end
Expand All @@ -51,7 +53,6 @@ end
@testset "size ($m, $n)" for n in (37, m, 63)
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
A = CuArray(randn(rng, T, m, n))
U, S, Vᴴ = svd_full(A; alg)
@test U isa CuMatrix{T} && size(U) == (m, m)
Expand Down Expand Up @@ -82,8 +83,26 @@ end
@test Sc === Sc2
@test CuArray(diagview(S)) ≈ Sc
# CuArray is necessary because norm of CuArray view with non-unit step is broken
if alg isa CUSOLVER_QRIteration
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
@test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad"))
end
end
end
@testset "size (0, 0)" begin
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
@testset "algorithm $alg" for alg in algs
A = CuArray(randn(rng, T, 0, 0))
U, S, Vᴴ = svd_full(A; alg)
@test U isa CuMatrix{T} && size(U) == (0, 0)
@test S isa CuMatrix{real(T)} && size(S) == (0, 0)
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0)
@test U * S * Vᴴ ≈ A
@test isapproxone(U' * U)
@test isapproxone(U * U')
@test isapproxone(Vᴴ * Vᴴ')
@test isapproxone(Vᴴ' * Vᴴ)
@test all(isposdef, diagview(S))
end
end
end
Expand All @@ -96,7 +115,6 @@ end
p = min(m, n) - k - 1
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100))
@testset "algorithm $alg" for alg in algs
n > m && alg isa CUSOLVER_QRIteration && continue # not supported
hA = randn(rng, T, m, n)
S₀ = svd_vals(hA)
A = CuArray(hA)
Expand Down
14 changes: 14 additions & 0 deletions test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ end
@test diagview(S) ≈ Sc
end
end
@testset "size (0, 0)" begin
@testset "algorithm $alg" for alg in
(LAPACK_DivideAndConquer(), LAPACK_QRIteration())
A = randn(rng, T, 0, 0)
U, S, Vᴴ = svd_full(A; alg)
@test U isa Matrix{T} && size(U) == (0, 0)
@test S isa Matrix{real(T)} && size(S) == (0, 0)
@test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0)
@test U * S * Vᴴ ≈ A
@test isunitary(U)
@test isunitary(Vᴴ)
@test all(isposdef, diagview(S))
end
end
end

@testset "svd_trunc! for T = $T" for T in BLASFloats
Expand Down