diff --git a/Project.toml b/Project.toml index 934c0ceb..d7044a33 100644 --- a/Project.toml +++ b/Project.toml @@ -38,8 +38,10 @@ Zygote = "0.7" julia = "1.10" [extras] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 68d25c36..e4c77453 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -8,30 +8,47 @@ using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd! -using CUDA +using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra using LinearAlgebra: BlasFloat include("yacusolver.jl") -function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix} +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} return CUSOLVER_HouseholderQR(; kwargs...) end -function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix} +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} qr_alg = CUSOLVER_HouseholderQR(; kwargs...) return LQViaTransposedQR(qr_alg) end -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix} +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} return CUSOLVER_QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix} +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} return CUSOLVER_Simple(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} return CUSOLVER_DivideAndConquer(; kwargs...) end +# include for block sector support +function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} + return CUSOLVER_HouseholderQR(; kwargs...) +end +function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} + qr_alg = CUSOLVER_HouseholderQR(; kwargs...) + return LQViaTransposedQR(qr_alg) +end +function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} + return CUSOLVER_Jacobi(; kwargs...) +end +function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} + return CUSOLVER_Simple(; kwargs...) +end +function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} + return CUSOLVER_DivideAndConquer(; kwargs...) +end _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 62f0f224..aba01fe6 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in ) chkstride1(A, U, Vᴴ, S) m, n = size(A) - (m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n")) + (m < n) && throw(ArgumentError(lazy"CUSOLVER's gesvd requires m ($m) ≥ n ($n)")) minmn = min(m, n) if length(U) == 0 jobu = 'N' @@ -191,14 +191,17 @@ for (bname, fname, elty, relty) in (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64), ) @eval begin + #! format: off function gesvdj!( A::StridedCuMatrix{$elty}, S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)), U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)), Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2)); tol::$relty = eps($relty), - max_sweeps::Int = 100 + max_sweeps::Int = 100, + kwargs... ) + #! format: on chkstride1(A, U, Vᴴ, S) m, n = size(A) minmn = min(m, n) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index d491ee94..8532c29a 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -137,7 +137,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) D, V = DV if alg isa GPU_Simple isempty(alg.kwargs) || - throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments")) + @warn "GPU_Simple (geev) does not accept any keyword arguments" _gpu_geev!(A, D.diag, V) end # TODO: make this controllable using a `gaugefix` keyword argument @@ -150,7 +150,7 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) if alg isa GPU_Simple isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments")) + @warn "GPU_Simple (geev) does not accept any keyword arguments" _gpu_geev!(A, D, V) end return D diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 0cfa6db0..13d0b9d3 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -50,6 +50,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA @check_scalar(V, A) return nothing end + function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm) check_hermitian(A, alg) @assert isdiag(A) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index fed36cd1..ba54f233 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -418,7 +418,7 @@ end _argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) _largest(x, y) = abs(x) < abs(y) ? y : x -function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) +function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa GPU_QRIteration diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl index 6ed44228..3223979c 100644 --- a/test/amd/orthnull.jl +++ b/test/amd/orthnull.jl @@ -25,7 +25,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) @test V * C ≈ A @test isisometric(V) - @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N) hV = collect(V) hN = collect(N) diff --git a/test/amd/polar.jl b/test/amd/polar.jl index 066c0f78..4040b674 100644 --- a/test/amd/polar.jl +++ b/test/amd/polar.jl @@ -13,7 +13,6 @@ using AMDGPU k = min(m, n) svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) @testset "algorithm $svd_alg" for svd_alg in svd_algs - n < m && svd_alg isa ROCSOLVER_QRIteration && continue A = ROCArray(randn(rng, T, m, n)) alg = PolarViaSVD(svd_alg) W, P = left_polar(A; alg) @@ -52,7 +51,6 @@ end k = min(m, n) svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) @testset "algorithm $svd_alg" for svd_alg in svd_algs - n > m && svd_alg isa ROCSOLVER_QRIteration && continue A = ROCArray(randn(rng, T, m, n)) alg = PolarViaSVD(svd_alg) P, Wᴴ = right_polar(A; alg) diff --git a/test/cuda/polar.jl b/test/cuda/polar.jl index a7a76f27..1f512367 100644 --- a/test/cuda/polar.jl +++ b/test/cuda/polar.jl @@ -13,7 +13,6 @@ using CUDA k = min(m, n) svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) @testset "algorithm $svd_alg" for svd_alg in svd_algs - n < m && svd_alg isa CUSOLVER_QRIteration && continue A = CuArray(randn(rng, T, m, n)) alg = PolarViaSVD(svd_alg) W, P = left_polar(A; alg) @@ -52,7 +51,6 @@ end k = min(m, n) svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) @testset "algorithm $svd_alg" for svd_alg in svd_algs - n > m && svd_alg isa CUSOLVER_QRIteration && continue A = CuArray(randn(rng, T, m, n)) alg = PolarViaSVD(svd_alg) P, Wᴴ = right_polar(A; alg)