Skip to content

Commit 9d1ffb8

Browse files
authored
Actually test GPU QR/LQ (#124)
* Actually test GPU QR/LQ * Run only the LQ/QR testsuites on BK
1 parent 2ef4ab1 commit 9d1ffb8

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/lq.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
1818
TestSuite.seed_rng!(123)
1919
if T BLASFloats
2020
if CUDA.functional()
21-
CUDA_LQ_ALGS = LQViaTransposedQR.(CUSOLVER_HouseholderLQ(; positive = false), CUSOLVER_HouseholderLQ(; positive = true))
21+
CUDA_LQ_ALGS = LQViaTransposedQR.((CUSOLVER_HouseholderQR(; positive = false), CUSOLVER_HouseholderQR(; positive = true)))
2222
TestSuite.test_lq(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
2323
TestSuite.test_lq_algs(CuMatrix{T}, (m, n), CUDA_LQ_ALGS)
2424
if n == m
@@ -27,9 +27,9 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
2727
end
2828
end
2929
if AMDGPU.functional()
30-
ROC_LQ_ALGS = LQViaTransposedQR.(ROCSOLVER_HouseholderLQ(; positive = false), ROCSOLVER_HouseholderLQ(; positive = true))
30+
ROC_LQ_ALGS = LQViaTransposedQR.((ROCSOLVER_HouseholderQR(; positive = false), ROCSOLVER_HouseholderQR(; positive = true)))
3131
TestSuite.test_lq(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
32-
TestSuite.test_lq_algs(ROCMatrix{T}, (m, n), CUDA_LQ_ALGS)
32+
TestSuite.test_lq_algs(ROCMatrix{T}, (m, n), ROC_LQ_ALGS)
3333
if n == m
3434
TestSuite.test_lq(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
3535
TestSuite.test_lq_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))

test/runtests.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using SafeTestsets
22

3+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
34
# don't run all tests on GPU, only the GPU
45
# specific ones
5-
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
66
if !is_buildkite
77
@safetestset "Algorithms" begin
88
include("algorithms.jl")
@@ -13,10 +13,6 @@ if !is_buildkite
1313
@safetestset "Truncate" begin
1414
include("truncate.jl")
1515
end
16-
@safetestset "QR / LQ Decomposition" begin
17-
include("qr.jl")
18-
include("lq.jl")
19-
end
2016
@safetestset "Singular Value Decomposition" begin
2117
include("svd.jl")
2218
end
@@ -71,6 +67,11 @@ if !is_buildkite
7167
end
7268
end
7369

70+
@safetestset "QR / LQ Decomposition" begin
71+
include("qr.jl")
72+
include("lq.jl")
73+
end
74+
7475
using CUDA
7576
if CUDA.functional()
7677
@safetestset "CUDA Projections" begin

0 commit comments

Comments
 (0)