Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ Zygote = "0.7"
julia = "1.10"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
29 changes: 23 additions & 6 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,47 @@ using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
using CUDA
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yacusolver.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

# include for block sector support
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Jacobi(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
Expand Down
7 changes: 5 additions & 2 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in
)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
(m < n) && throw(ArgumentError(lazy"CUSOLVER's gesvd requires m ($m) ≥ n ($n)"))
minmn = min(m, n)
if length(U) == 0
jobu = 'N'
Expand Down Expand Up @@ -191,14 +191,17 @@ for (bname, fname, elty, relty) in
(:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64),
)
@eval begin
#! format: off
function gesvdj!(
A::StridedCuMatrix{$elty},
S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)),
U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
tol::$relty = eps($relty),
max_sweeps::Int = 100
max_sweeps::Int = 100,
kwargs...
)
#! format: on
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
Expand Down
4 changes: 2 additions & 2 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
D, V = DV
if alg isa GPU_Simple
isempty(alg.kwargs) ||
throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
@warn "GPU_Simple (geev) does not accept any keyword arguments"
_gpu_geev!(A, D.diag, V)
end
# TODO: make this controllable using a `gaugefix` keyword argument
Expand All @@ -150,7 +150,7 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa GPU_Simple
isempty(alg.kwargs) ||
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
@warn "GPU_Simple (geev) does not accept any keyword arguments"
_gpu_geev!(A, D, V)
end
return D
Expand Down
1 change: 1 addition & 0 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
@check_scalar(V, A)
return nothing
end

function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
check_hermitian(A, alg)
@assert isdiag(A)
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ end
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
_largest(x, y) = abs(x) < abs(y) ? y : x

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
check_input(svd_vals!, A, S, alg)
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
if alg isa GPU_QRIteration
Expand Down
2 changes: 1 addition & 1 deletion test/amd/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
@test V * C ≈ A
@test isisometric(V)
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N)
hV = collect(V)
hN = collect(N)
Expand Down
2 changes: 0 additions & 2 deletions test/amd/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using AMDGPU
k = min(m, n)
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n < m && svd_alg isa ROCSOLVER_QRIteration && continue
A = ROCArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
W, P = left_polar(A; alg)
Expand Down Expand Up @@ -52,7 +51,6 @@ end
k = min(m, n)
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n > m && svd_alg isa ROCSOLVER_QRIteration && continue
A = ROCArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
P, Wᴴ = right_polar(A; alg)
Expand Down
2 changes: 0 additions & 2 deletions test/cuda/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using CUDA
k = min(m, n)
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n < m && svd_alg isa CUSOLVER_QRIteration && continue
A = CuArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
W, P = left_polar(A; alg)
Expand Down Expand Up @@ -52,7 +51,6 @@ end
k = min(m, n)
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n > m && svd_alg isa CUSOLVER_QRIteration && continue
A = CuArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
P, Wᴴ = right_polar(A; alg)
Expand Down