From 59407655d5555260a041c472e7135347a4769056 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 01:45:51 -0500 Subject: [PATCH 1/2] lmul/rmul support for Diagonals --- src/host/linalg.jl | 2 ++ test/testsuite/linalg.jl | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index bc599968..c079944a 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -683,6 +683,7 @@ function generic_rmul!(X::AbstractArray, s::Number) end LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b) +LinearAlgebra.rmul!(A::Diagonal{T, <:AbstractGPUArray}, b::Number) where {T} = A .* b function generic_lmul!(s::Number, X::AbstractArray) @kernel function lmul_kernel!(X, s) @@ -694,6 +695,7 @@ function generic_lmul!(s::Number, X::AbstractArray) end LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B) +LinearAlgebra.lmul!(a::Number, B::Diagonal{T, <:AbstractGPUArray}) where {T} = a .* B ## permutedims diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index fa701084..31637977 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -437,6 +437,21 @@ end A_empty = randn(Float32, 0, 0) @test compare(f, AT, A_empty, d) end + + @testset "rmul!/lmul! with diagonal and number" begin + n = 32 + h_d = rand(Float32, n) + h_D = Diagonal(h_d) + d = AT(h_d) + D = Diagonal(d) + a = rand(Float32) + rmul!(D, a) + rmul!(h_D, a) + @test collect(D) ≈ h_D + lmul!(a, D) + lmul!(a, h_D) + @test collect(D) ≈ h_D + end end @testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin From 22c437df98f334e42d79b017169a1f1013cc9a31 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 02:14:20 -0500 Subject: [PATCH 2/2] Less stupid way --- src/host/linalg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index c079944a..1558dc5a 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -683,7 +683,7 @@ function generic_rmul!(X::AbstractArray, s::Number) end LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b) -LinearAlgebra.rmul!(A::Diagonal{T, <:AbstractGPUArray}, b::Number) where {T} = A .* b +LinearAlgebra.rmul!(A::Diagonal{T, <:AbstractGPUArray}, b::Number) where {T} = generic_rmul!(A.diag, b) function generic_lmul!(s::Number, X::AbstractArray) @kernel function lmul_kernel!(X, s) @@ -695,7 +695,7 @@ function generic_lmul!(s::Number, X::AbstractArray) end LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B) -LinearAlgebra.lmul!(a::Number, B::Diagonal{T, <:AbstractGPUArray}) where {T} = a .* B +LinearAlgebra.lmul!(a::Number, B::Diagonal{T, <:AbstractGPUArray}) where {T} = generic_lmul!(a, B.diag) ## permutedims