Skip to content

Commit dd7c3ad

Browse files
author
Katharine Hyatt
committed
Only use the scalar method for AMDGPU
1 parent 34839bd commit dd7c3ad

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,15 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
154154
return A, B
155155
end
156156

157+
function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS}
158+
# TODO: avoid allocation?
159+
U, S = US
160+
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
161+
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
162+
trunc_cols = collect(1:size(U, 2))[ind]
163+
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
164+
Utrunc .= U[:, trunc_cols]
165+
return Utrunc, ind
166+
end
167+
157168
end

src/implementations/truncation.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@ function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
1717
# TODO: avoid allocation?
1818
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
1919
ind = findtruncated(extended_S, strategy)
20-
trunc_cols = collect(1:size(U, 2))[ind]
21-
Utrunc = similar(U, (size(U, 1), length(trunc_cols)))
22-
Utrunc .= U[:, trunc_cols]
23-
return Utrunc, ind
20+
return U[:, ind], ind
2421
end
2522
function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
2623
# TODO: avoid allocation?

0 commit comments

Comments
 (0)