Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/implementations/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ copy_input(::typeof(right_polar), A) = copy_input(svd_full, A)
function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlgorithm)
m, n = size(A)
W, P = WP
m >= n ||
throw(ArgumentError("input matrix needs at least as many rows as columns"))
m n ||
throw(ArgumentError("input matrix needs at least as many rows ($m) as columns ($n)"))
@assert W isa AbstractMatrix && P isa AbstractMatrix
@check_size(W, (m, n))
@check_scalar(W, A)
Expand All @@ -18,8 +18,8 @@ end
function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::AbstractAlgorithm)
m, n = size(A)
P, Wᴴ = PWᴴ
n >= m ||
throw(ArgumentError("input matrix needs at least as many columns as rows"))
n m ||
throw(ArgumentError("input matrix needs at least as many columns ($n) as rows ($m)"))
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
isempty(P) || @check_size(P, (m, m))
@check_scalar(P, A)
Expand Down Expand Up @@ -107,19 +107,19 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
if m > n # initial QR
Q, R = qr_compact!(A)
Rc = view(A, 1:n, 1:n)
copy!(Rc, R)
Rc .= R
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is copy! still problematic on GPU?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copying to a SubArray is

Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
else # m == n
R = A
Rc = view(W, 1:n, 1:n)
copy!(Rc, R)
Rc .= R
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
end
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
rmul!(R, γ)
rmul!(Rᴴinv, 1 / γ)
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
copy!(Rc, R)
Rc .= R
i = 1
conv = norm(Rᴴinv, Inf)
while i < maxiter && conv > tol
Expand All @@ -128,7 +128,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
rmul!(R, γ)
rmul!(Rᴴinv, 1 / γ)
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
copy!(Rc, R)
Rc .= R
conv = norm(Rᴴinv, Inf)
i += 1
end
Expand All @@ -152,7 +152,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
else # m == n
L = A
Lc = view(Wᴴ, 1:m, 1:m)
copy!(Lc, L)
Lc .= L
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
end
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
Expand All @@ -168,7 +168,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
rmul!(L, γ)
rmul!(Lᴴinv, 1 / γ)
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
copy!(Lc, L)
Lc .= L
conv = norm(Lᴴinv, Inf)
i += 1
end
Expand Down
16 changes: 8 additions & 8 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
jobu = 'N'
else
size(U, 1) == m ||
throw(DimensionMismatch("row size mismatch between A and U"))
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
size(U, 2) >= (range == 'I' ? iu - il + 1 : minmn) ||
throw(DimensionMismatch("invalid column size of U"))
jobu = 'V'
Expand All @@ -2171,13 +2171,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
jobvt = 'N'
else
size(Vᴴ, 2) == n ||
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
size(Vᴴ, 1) >= (range == 'I' ? iu - il + 1 : minmn) ||
throw(DimensionMismatch("invalid row size of Vᴴ"))
jobvt = 'V'
end
length(S) == minmn ||
throw(DimensionMismatch("length mismatch between A and S"))
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))

lda = max(1, stride(A, 2))
ldu = max(1, stride(U, 2))
Expand Down Expand Up @@ -2247,15 +2247,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
require_one_based_indexing(A, U, Vᴴ, S)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
m >= n ||
throw(ArgumentError("gejsv! requires a matrix with at least as many rows as columns"))
m n ||
throw(ArgumentError("gejsv! requires a matrix with at least as many rows ($m) as columns ($n)"))

joba = 'G'
if length(U) == 0
jobu = 'N'
else
size(U, 1) == m ||
throw(DimensionMismatch("row size mismatch between A and U"))
throw(DimensionMismatch("row size mismatch between A ($m) and U ($(size(U, 1)))"))
if size(U, 2) == n
jobu = 'U'
elseif size(U, 2) == m
Expand All @@ -2268,15 +2268,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
jobv = 'N'
else
size(Vᴴ, 2) == n ||
throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
throw(DimensionMismatch("column size mismatch between A ($n) and Vᴴ ($(size(Vᴴ, 2)))"))
if size(Vᴴ, 1) == n
jobv = 'V'
else
throw(DimensionMismatch("invalid row size of Vᴴ"))
end
end
length(S) == n ||
throw(DimensionMismatch("length mismatch between A and S"))
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))

lda = max(1, stride(A, 2))
mv = Ref{BlasInt}() # unused
Expand Down
83 changes: 0 additions & 83 deletions test/amd/polar.jl

This file was deleted.

83 changes: 0 additions & 83 deletions test/cuda/polar.jl

This file was deleted.

3 changes: 2 additions & 1 deletion test/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
TestSuite.test_lq_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
end
end
elseif !is_buildkite
end
if !is_buildkite
if T ∈ BLASFloats
TestSuite.test_lq(T, (m, n))
LAPACK_LQ_ALGS = (
Expand Down
Loading