Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,14 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
return A, B
end

function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS}
# TODO: avoid allocation?
U, S = US
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
trunc_cols = collect(1:size(U, 2))[ind]
Copy link
Member

@lkdvos lkdvos Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
trunc_cols = collect(1:size(U, 2))[ind]
trunc_cols = ind isa AbstractVector{Integer} ? ind : axes(U, 2)[ind]

Does this happen to work as well?

Utrunc = U[:, trunc_cols]
return Utrunc, ind
end

end
342 changes: 342 additions & 0 deletions test/amd/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, I, mul!, diagm, norm
using MatrixAlgebraKit: GPU_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
initialize_output, AbstractAlgorithm
using AMDGPU

# testing non-AbstractArray codepaths:
include(joinpath("..", "linearmap.jl"))

@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
@testset for n in (37, m, 63)
minmn = min(m, n)
A = ROCArray(randn(rng, T, m, n))
V, C = @constinferred left_orth(A)
N = @constinferred left_null(A)
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
@test V * C ≈ A
@test isisometric(V)
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N)
hV = collect(V)
hN = collect(N)
@test hV * hV' + hN * hN' ≈ I

M = LinearMap(A)
VM, CM = @constinferred left_orth(M; kind = :svd)
@test parent(VM) * parent(CM) ≈ A

if m > n
nullity = 5
V, C = @constinferred left_orth(A)
AMDGPU.@allowscalar begin
N = @constinferred left_null(A; trunc = (; maxnullity = nullity))
end
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
@test N isa ROCMatrix{T} && size(N) == (m, nullity)
@test V * C ≈ A
@test isisometric(V)
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N)
end

for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR())
V, C = @constinferred left_orth(A; alg_qr)
N = @constinferred left_null(A; alg_qr)
@test V isa ROCMatrix{T} && size(V) == (m, minmn)
@test C isa ROCMatrix{T} && size(C) == (minmn, n)
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
@test V * C ≈ A
@test isisometric(V)
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N)
hV = collect(V)
hN = collect(N)
@test hV * hV' + hN * hN' ≈ I
end

Ac = similar(A)
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C))
N2 = @constinferred left_null!(copy!(Ac, A), N)
@test V2 === V
@test C2 === C
@test N2 === N
@test V2 * C2 ≈ A
@test isisometric(V2)
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I

atol = eps(real(T))
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol))
AMDGPU.@allowscalar begin
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol))
end
@test V2 !== V
@test C2 !== C
@test N2 !== C
@test V2 * C2 ≈ A
@test isisometric(V2)
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I

rtol = eps(real(T))
for (trunc_orth, trunc_null) in (
((; rtol = rtol), (; rtol = rtol)),
(trunctol(; rtol), trunctol(; rtol, keep_below = true)),
)
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth)
AMDGPU.@allowscalar begin
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null)
end
@test V2 !== V
@test C2 !== C
@test N2 !== C
@test V2 * C2 ≈ A
@test isisometric(V2)
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I
end

@testset for kind in (:qr, :polar, :svd) # explicit kind kwarg
m < n && kind == :polar && continue
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind)
@test V2 === V
@test C2 === C
@test V2 * C2 ≈ A
@test isisometric(V2)
if kind != :polar
N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind)
@test N2 === N
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I
end

# with kind and tol kwargs
if kind == :svd
V2, C2 = @constinferred left_orth!(
copy!(Ac, A), (V, C); kind = kind,
trunc = (; atol = atol)
)
AMDGPU.@allowscalar begin
N2 = @constinferred left_null!(
copy!(Ac, A), N; kind = kind,
trunc = (; atol = atol)
)
end
@test V2 !== V
@test C2 !== C
@test N2 !== C
@test V2 * C2 ≈ A
@test V2' * V2 ≈ I
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I

V2, C2 = @constinferred left_orth!(
copy!(Ac, A), (V, C); kind = kind,
trunc = (; rtol = rtol)
)
AMDGPU.@allowscalar begin
N2 = @constinferred left_null!(
copy!(Ac, A), N; kind = kind,
trunc = (; rtol = rtol)
)
end
@test V2 !== V
@test C2 !== C
@test N2 !== C
@test V2 * C2 ≈ A
@test isisometric(V2)
@test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N2)
hV2 = collect(V2)
hN2 = collect(N2)
@test hV2 * hV2' + hN2 * hN2' ≈ I
else
@test_throws ArgumentError left_orth!(
copy!(Ac, A), (V, C); kind = kind,
trunc = (; atol = atol)
)
@test_throws ArgumentError left_orth!(
copy!(Ac, A), (V, C); kind = kind,
trunc = (; rtol = rtol)
)
@test_throws ArgumentError left_null!(
copy!(Ac, A), N; kind = kind,
trunc = (; atol = atol)
)
@test_throws ArgumentError left_null!(
copy!(Ac, A), N; kind = kind,
trunc = (; rtol = rtol)
)
end
end
end
end

@testset "right_orth and right_null for T = $T" for T in (
Float32, Float64, ComplexF32,
ComplexF64,
)
rng = StableRNG(123)
m = 54
@testset for n in (37, m, 63)
minmn = min(m, n)
A = ROCArray(randn(rng, T, m, n))
C, Vᴴ = @constinferred right_orth(A)
Nᴴ = @constinferred right_null(A)
@test C isa ROCMatrix{T} && size(C) == (m, minmn)
@test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n)
@test Nᴴ isa ROCMatrix{T} && size(Nᴴ) == (n - minmn, n)
@test C * Vᴴ ≈ A
@test isisometric(Vᴴ; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ; side = :right)
hVᴴ = collect(Vᴴ)
hNᴴ = collect(Nᴴ)
@test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ ≈ I

M = LinearMap(A)
CM, VMᴴ = @constinferred right_orth(M; kind = :svd)
@test parent(CM) * parent(VMᴴ) ≈ A

Ac = similar(A)
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
@test C2 === C
@test Vᴴ2 === Vᴴ
@test Nᴴ2 === Nᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I

atol = eps(real(T))
rtol = eps(real(T))
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol))
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol))
@test C2 !== C
@test Vᴴ2 !== Vᴴ
@test Nᴴ2 !== Nᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I

C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol))
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol))
@test C2 !== C
@test Vᴴ2 !== Vᴴ
@test Nᴴ2 !== Nᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ2; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I

@testset "kind = $kind" for kind in (:lq, :polar, :svd)
n < m && kind == :polar && continue
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind)
@test C2 === C
@test Vᴴ2 === Vᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
if kind != :polar
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind)
@test Nᴴ2 === Nᴴ
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ2; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I
end

if kind == :svd
C2, Vᴴ2 = @constinferred right_orth!(
copy!(Ac, A), (C, Vᴴ); kind = kind,
trunc = (; atol = atol)
)
Nᴴ2 = @constinferred right_null!(
copy!(Ac, A), Nᴴ; kind = kind,
trunc = (; atol = atol)
)
@test C2 !== C
@test Vᴴ2 !== Vᴴ
@test Nᴴ2 !== Nᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ2; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I

C2, Vᴴ2 = @constinferred right_orth!(
copy!(Ac, A), (C, Vᴴ); kind = kind,
trunc = (; rtol = rtol)
)
Nᴴ2 = @constinferred right_null!(
copy!(Ac, A), Nᴴ; kind = kind,
trunc = (; rtol = rtol)
)
@test C2 !== C
@test Vᴴ2 !== Vᴴ
@test Nᴴ2 !== Nᴴ
@test C2 * Vᴴ2 ≈ A
@test isisometric(Vᴴ2; side = :right)
@test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(Nᴴ2; side = :right)
hVᴴ2 = collect(Vᴴ2)
hNᴴ2 = collect(Nᴴ2)
@test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T)
else
@test_throws ArgumentError right_orth!(
copy!(Ac, A), (C, Vᴴ); kind = kind,
trunc = (; atol = atol)
)
@test_throws ArgumentError right_orth!(
copy!(Ac, A), (C, Vᴴ); kind = kind,
trunc = (; rtol = rtol)
)
@test_throws ArgumentError right_null!(
copy!(Ac, A), Nᴴ; kind = kind,
trunc = (; atol = atol)
)
@test_throws ArgumentError right_null!(
copy!(Ac, A), Nᴴ; kind = kind,
trunc = (; rtol = rtol)
)
end
end
end
end
Loading