Skip to content

Commit 178e892

Browse files
lkdvoskshyatt
authored andcommitted
add testsuite
1 parent ba24875 commit 178e892

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1657
-4278
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ GenericSchur = "0.5.6"
3333
JET = "0.9, 0.10"
3434
LinearAlgebra = "1"
3535
Mooncake = "0.4.174"
36+
Random = "1"
3637
SafeTestsets = "0.1"
3738
StableRNGs = "1"
3839
Test = "1"
39-
TestExtras = "0.2,0.3"
40+
TestExtras = "0.3.2"
4041
Zygote = "0.7"
4142
julia = "1.10"
4243

@@ -47,11 +48,12 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4748
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4849
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4950
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
51+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5052
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5153
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
5254
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5355
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5456
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5557

5658
[targets]
57-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
59+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake"]

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -159,43 +159,4 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
159159
return A, B
160160
end
161161

162-
function MatrixAlgebraKit.truncate(
163-
::typeof(left_null!), US::Tuple{TU, TS}, strategy::TruncationStrategy
164-
) where {TU <: ROCMatrix, TS}
165-
# TODO: avoid allocation?
166-
U, S = US
167-
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
168-
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
169-
trunc_cols = collect(1:size(U, 2))[ind]
170-
Utrunc = U[:, trunc_cols]
171-
return Utrunc, ind
172-
end
173-
function MatrixAlgebraKit.truncate(
174-
::typeof(right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::TruncationStrategy
175-
) where {TS, TVᴴ <: ROCMatrix}
176-
# TODO: avoid allocation?
177-
S, Vᴴ = SVᴴ
178-
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
179-
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
180-
trunc_rows = collect(1:size(Vᴴ, 1))[ind]
181-
Vᴴtrunc = Vᴴ[trunc_rows, :]
182-
return Vᴴtrunc, ind
183-
end
184-
185-
# disambiguate:
186-
function MatrixAlgebraKit.truncate(
187-
::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation
188-
) where {TU <: ROCMatrix, TS}
189-
m, n = size(S)
190-
ind = (n + 1):m
191-
return U[:, ind], ind
192-
end
193-
function MatrixAlgebraKit.truncate(
194-
::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation
195-
) where {TS, TVᴴ <: ROCMatrix}
196-
m, n = size(S)
197-
ind = (m + 1):n
198-
return Vᴴ[ind, :], ind
199-
end
200-
201162
end

src/implementations/qr.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,12 @@ function _gpu_unmqr!(
270270
end
271271

272272
function _gpu_qr!(
273-
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = false, blocksize = 1
273+
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; pivoted = false, positive = false, blocksize = 1
274274
)
275275
blocksize > 1 &&
276276
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
277+
pivoted &&
278+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
277279
m, n = size(A)
278280
minmn = min(m, n)
279281
computeR = length(R) > 0
@@ -309,10 +311,12 @@ function _gpu_qr!(
309311
end
310312

311313
function _gpu_qr_null!(
312-
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1
314+
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1, pivoted = false
313315
)
314316
blocksize > 1 &&
315317
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
318+
pivoted &&
319+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
316320
m, n = size(A)
317321
minmn = min(m, n)
318322
fill!(N, zero(eltype(N)))

src/implementations/schur.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,29 @@ function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractA
2626
return nothing
2727
end
2828

29+
function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::DiagonalAlgorithm)
30+
m, n = size(A)
31+
@assert m == n && isdiag(A)
32+
T, Z, vals = TZv
33+
@assert vals isa AbstractVector && Z isa Diagonal
34+
@check_scalar(T, A)
35+
@check_size(Z, (m, m))
36+
@check_scalar(Z, A)
37+
@check_size(vals, (n,))
38+
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
39+
@check_scalar(vals, A)
40+
return nothing
41+
end
42+
function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::DiagonalAlgorithm)
43+
m, n = size(A)
44+
@assert m == n && isdiag(A)
45+
@assert vals isa AbstractVector
46+
@check_size(vals, (n,))
47+
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
48+
@check_scalar(vals, A)
49+
return nothing
50+
end
51+
2952
# Outputs
3053
# -------
3154
function initialize_output(::typeof(schur_full!), A::AbstractMatrix, ::AbstractAlgorithm)
@@ -39,6 +62,17 @@ function initialize_output(::typeof(schur_vals!), A::AbstractMatrix, ::AbstractA
3962
vals = similar(A, complex(eltype(A)), n)
4063
return vals
4164
end
65+
function initialize_output(::typeof(schur_full!), A::Diagonal, ::DiagonalAlgorithm)
66+
n = size(A, 1)
67+
Z = similar(A)
68+
vals = similar(A, eltype(A), n)
69+
return (A, Z, vals)
70+
end
71+
function initialize_output(::typeof(schur_vals!), A::Diagonal, ::DiagonalAlgorithm)
72+
n = size(A, 1)
73+
vals = similar(A, eltype(A), n)
74+
return vals
75+
end
4276

4377
# Implementation
4478
# --------------
@@ -72,3 +106,20 @@ function schur_vals!(A::AbstractMatrix, vals, alg::LAPACK_EigAlgorithm)
72106
end
73107
return vals
74108
end
109+
110+
# Diagonal logic
111+
# --------------
112+
function schur_full!(A::Diagonal, (T, Z, vals)::Tuple{Diagonal, Diagonal, <:AbstractVector}, alg::DiagonalAlgorithm)
113+
check_input(schur_full!, A, (T, Z, vals), alg)
114+
copy!(vals, diagview(A))
115+
one!(Z)
116+
T === A || copy!(T, A)
117+
return T, Z, vals
118+
end
119+
120+
function schur_vals!(A::Diagonal, vals::AbstractVector, alg::DiagonalAlgorithm)
121+
check_input(schur_vals!, A, vals, alg)
122+
Ad = diagview(A)
123+
vals === Ad || copy!(vals, Ad)
124+
return vals
125+
end

src/implementations/svd.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ end
152152
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
153153
check_input(svd_compact!, A, USVᴴ, alg)
154154
U, S, Vᴴ = USVᴴ
155+
if length(A) == 0
156+
one!(U)
157+
zero!(S)
158+
one!(Vᴴ)
159+
return USVᴴ
160+
end
155161

156162
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
157163
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -382,6 +388,12 @@ end
382388
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
383389
check_input(svd_compact!, A, USVᴴ, alg)
384390
U, S, Vᴴ = USVᴴ
391+
if length(A) == 0
392+
one!(U)
393+
zero!(S)
394+
one!(Vᴴ)
395+
return USVᴴ
396+
end
385397

386398
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
387399
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
@@ -406,6 +418,10 @@ _largest(x, y) = abs(x) < abs(y) ? y : x
406418

407419
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
408420
check_input(svd_vals!, A, S, alg)
421+
if length(A) == 0
422+
zero!(S)
423+
return S
424+
end
409425
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
410426

411427
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})

test/amd/eigh.jl

Lines changed: 0 additions & 105 deletions
This file was deleted.

0 commit comments

Comments
 (0)