Skip to content

Commit 1ba0502

Browse files
committed
Testsuite for polar
1 parent 9d1ffb8 commit 1ba0502

File tree

10 files changed

+140
-262
lines changed

10 files changed

+140
-262
lines changed

src/implementations/polar.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ copy_input(::typeof(right_polar), A) = copy_input(svd_full, A)
66
function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlgorithm)
77
m, n = size(A)
88
W, P = WP
9-
m >= n ||
10-
throw(ArgumentError("input matrix needs at least as many rows as columns"))
9+
m n ||
10+
throw(ArgumentError("input matrix needs at least as many rows ($m) as columns ($n)"))
1111
@assert W isa AbstractMatrix && P isa AbstractMatrix
1212
@check_size(W, (m, n))
1313
@check_scalar(W, A)
@@ -18,8 +18,8 @@ end
1818
function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::AbstractAlgorithm)
1919
m, n = size(A)
2020
P, Wᴴ = PWᴴ
21-
n >= m ||
22-
throw(ArgumentError("input matrix needs at least as many columns as rows"))
21+
n m ||
22+
throw(ArgumentError("input matrix needs at least as many columns ($n) as rows ($m)"))
2323
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
2424
isempty(P) || @check_size(P, (m, m))
2525
@check_scalar(P, A)
@@ -152,7 +152,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
152152
else # m == n
153153
L = A
154154
Lc = view(Wᴴ, 1:m, 1:m)
155-
copy!(Lc, L)
155+
Lc .= L
156156
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
157157
end
158158
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
@@ -168,7 +168,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
168168
rmul!(L, γ)
169169
rmul!(Lᴴinv, 1 / γ)
170170
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
171-
copy!(Lc, L)
171+
Lc .= L
172172
conv = norm(Lᴴinv, Inf)
173173
i += 1
174174
end

src/yalapack.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,7 +2162,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21622162
jobu = 'N'
21632163
else
21642164
size(U, 1) == m ||
2165-
throw(DimensionMismatch("row size mismatch between A and U"))
2165+
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
21662166
size(U, 2) >= (range == 'I' ? iu - il + 1 : minmn) ||
21672167
throw(DimensionMismatch("invalid column size of U"))
21682168
jobu = 'V'
@@ -2171,13 +2171,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
21712171
jobvt = 'N'
21722172
else
21732173
size(Vᴴ, 2) == n ||
2174-
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
2174+
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
21752175
size(Vᴴ, 1) >= (range == 'I' ? iu - il + 1 : minmn) ||
21762176
throw(DimensionMismatch("invalid row size of Vᴴ"))
21772177
jobvt = 'V'
21782178
end
21792179
length(S) == minmn ||
2180-
throw(DimensionMismatch("length mismatch between A and S"))
2180+
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
21812181

21822182
lda = max(1, stride(A, 2))
21832183
ldu = max(1, stride(U, 2))
@@ -2247,15 +2247,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
22472247
require_one_based_indexing(A, U, Vᴴ, S)
22482248
chkstride1(A, U, Vᴴ, S)
22492249
m, n = size(A)
2250-
m >= n ||
2251-
throw(ArgumentError("gejsv! requires a matrix with at least as many rows as columns"))
2250+
m n ||
2251+
throw(ArgumentError("gejsv! requires a matrix with at least as many rows ($m) as columns ($n)"))
22522252

22532253
joba = 'G'
22542254
if length(U) == 0
22552255
jobu = 'N'
22562256
else
22572257
size(U, 1) == m ||
2258-
throw(DimensionMismatch("row size mismatch between A and U"))
2258+
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
22592259
if size(U, 2) == n
22602260
jobu = 'U'
22612261
elseif size(U, 2) == m
@@ -2268,15 +2268,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
22682268
jobv = 'N'
22692269
else
22702270
size(Vᴴ, 2) == n ||
2271-
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
2271+
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
22722272
if size(Vᴴ, 1) == n
22732273
jobv = 'V'
22742274
else
22752275
throw(DimensionMismatch("invalid row size of Vᴴ"))
22762276
end
22772277
end
22782278
length(S) == n ||
2279-
throw(DimensionMismatch("length mismatch between A and S"))
2279+
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
22802280

22812281
lda = max(1, stride(A, 2))
22822282
mv = Ref{BlasInt}() # unused

test/amd/polar.jl

Lines changed: 0 additions & 83 deletions
This file was deleted.

test/cuda/polar.jl

Lines changed: 0 additions & 83 deletions
This file was deleted.

test/lq.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
3535
TestSuite.test_lq_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
3636
end
3737
end
38-
elseif !is_buildkite
38+
end
39+
if !is_buildkite
3940
if T BLASFloats
4041
TestSuite.test_lq(T, (m, n))
4142
LAPACK_LQ_ALGS = (

test/polar.jl

Lines changed: 33 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,45 @@
11
using MatrixAlgebraKit
22
using Test
3-
using TestExtras
43
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I, isposdef
4+
using LinearAlgebra: Diagonal
5+
using CUDA, AMDGPU
66

7-
@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
8-
rng = StableRNG(123)
9-
m = 54
10-
@testset "size ($m, $n)" for n in (37, m)
11-
k = min(m, n)
12-
if LinearAlgebra.LAPACK.version() < v"3.12.0"
13-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
14-
else
15-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi())
16-
end
17-
algs = (PolarViaSVD.(svdalgs)..., PolarNewton())
18-
@testset "algorithm $alg" for alg in algs
19-
A = randn(rng, T, m, n)
7+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
GenericFloats = (BigFloat, Complex{BigFloat})
209

21-
W, P = left_polar(A; alg)
22-
@test W isa Matrix{T} && size(W) == (m, n)
23-
@test P isa Matrix{T} && size(P) == (n, n)
24-
@test W * P A
25-
@test isisometric(W)
26-
@test isposdef(P)
10+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
11+
using .TestSuite
2712

28-
Ac = similar(A)
29-
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg)
30-
@test W2 === W
31-
@test P2 === P
32-
@test W * P A
33-
@test isisometric(W)
34-
@test isposdef(P)
13+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
3514

36-
noP = similar(P, (0, 0))
37-
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg)
38-
@test P2 === noP
39-
@test W2 === W
40-
@test isisometric(W)
41-
P = W' * A # compute P explicitly to verify W correctness
42-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
43-
@test isposdef(project_hermitian!(P))
15+
m = 54
16+
for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
17+
TestSuite.seed_rng!(123)
18+
if T BLASFloats
19+
if CUDA.functional()
20+
CUDA_POLAR_ALGS = (PolarViaSVD.((CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()))..., PolarNewton())
21+
TestSuite.test_polar(CuMatrix{T}, (m, n), CUDA_POLAR_ALGS)
22+
n == m && TestSuite.test_polar(Diagonal{T, CuVector{T}}, m, (PolarNewton(),))
23+
end
24+
if AMDGPU.functional()
25+
ROC_POLAR_ALGS = (PolarViaSVD.((ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()))..., PolarNewton())
26+
TestSuite.test_polar(ROCMatrix{T}, (m, n), ROC_POLAR_ALGS)
27+
n == m && TestSuite.test_polar(Diagonal{T, ROCVector{T}}, m, (PolarNewton(),))
4428
end
4529
end
46-
end
47-
48-
@testset "right_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
49-
rng = StableRNG(123)
50-
n = 54
51-
@testset "size ($m, $n)" for m in (37, n)
52-
k = min(m, n)
53-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
54-
algs = (PolarViaSVD.(svdalgs)..., PolarNewton())
55-
@testset "algorithm $alg" for alg in algs
56-
A = randn(rng, T, m, n)
57-
58-
P, Wᴴ = right_polar(A; alg)
59-
@test Wᴴ isa Matrix{T} && size(Wᴴ) == (m, n)
60-
@test P isa Matrix{T} && size(P) == (m, m)
61-
@test P * Wᴴ A
62-
@test isisometric(Wᴴ; side = :right)
63-
@test isposdef(P)
64-
65-
Ac = similar(A)
66-
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg)
67-
@test P2 === P
68-
@test Wᴴ2 === Wᴴ
69-
@test P * Wᴴ A
70-
@test isisometric(Wᴴ; side = :right)
71-
@test isposdef(P)
72-
73-
noP = similar(P, (0, 0))
74-
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
75-
@test P2 === noP
76-
@test Wᴴ2 === Wᴴ
77-
@test isisometric(Wᴴ; side = :right)
78-
P = A * Wᴴ' # compute P explicitly to verify W correctness
79-
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
80-
@test isposdef(project_hermitian!(P))
30+
if !is_buildkite
31+
if T BLASFloats
32+
LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_DivideAndConquer()))..., PolarNewton())
33+
TestSuite.test_polar(T, (m, n), LAPACK_POLAR_ALGS)
34+
LAPACK_JACOBI = (PolarViaSVD(LAPACK_Jacobi()),)
35+
TestSuite.test_polar(T, (m, n), LAPACK_JACOBI; test_right=false)
36+
elseif T GenericFloats
37+
GLA_POLAR_ALGS = (PolarViaSVD.((GLA_QRIteration(),))..., PolarNewton())
38+
TestSuite.test_polar(T, (m, n), GLA_POLAR_ALGS)
39+
end
40+
if m == n
41+
AT = Diagonal{T, Vector{T}}
42+
TestSuite.test_polar(AT, m, (PolarNewton(),))
8143
end
8244
end
8345
end

0 commit comments

Comments
 (0)