From d4690e404fe39caf120ec6acf4b95db6a7831ee8 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 17 Dec 2025 05:59:10 -0500 Subject: [PATCH] Add GPU-compatible iszero for sparse arrays MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a Base.iszero method for AbstractGPUSparseArray that avoids scalar indexing by using GPU-compatible operations: - nnz(A) == 0 checks if there are no stored elements - all(iszero, nonzeros(A)) uses GPU reduction on stored values This fixes compatibility with packages like SciMLOperators.jl that call iszero on operators during arithmetic operations. The implementation covers all GPU sparse array types (vectors, CSC, CSR, BSR, COO matrices) through the abstract supertype. Addresses: JuliaGPU/CUDA.jl#2997 Ref: https://github.com/SciML/SciMLOperators.jl/issues/338 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/host/sparse.jl | 5 ++++ test/testsuite/sparse.jl | 61 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/host/sparse.jl b/src/host/sparse.jl index 3fae31877..c5608b9d1 100644 --- a/src/host/sparse.jl +++ b/src/host/sparse.jl @@ -36,6 +36,11 @@ Base.Array(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x)) Base.Array(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x)) Base.Array(x::AbstractGPUSparseMatrixCOO) = collect(SparseMatrixCSC(x)) +# iszero that avoids scalar indexing by using GPU-compatible reduction +# nnz(A) == 0 means no stored elements, so it's definitely zero +# all(iszero, nonzeros(A)) uses GPU reduction to check stored values +Base.iszero(A::AbstractGPUSparseArray) = SparseArrays.nnz(A) == 0 || all(iszero, SparseArrays.nonzeros(A)) + SparseArrays.SparseVector(x::AbstractGPUSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x))) SparseArrays.SparseMatrixCSC(x::AbstractGPUSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(SparseArrays.getcolptr(x)), Array(SparseArrays.rowvals(x)), Array(SparseArrays.nonzeros(x))) diff --git a/test/testsuite/sparse.jl b/test/testsuite/sparse.jl index 08c5ab12c..a31abe6d5 100644 --- a/test/testsuite/sparse.jl +++ b/test/testsuite/sparse.jl @@ -5,12 +5,14 @@ vector(sparse_AT, eltypes) vector_construction(sparse_AT, eltypes) broadcasting_vector(sparse_AT, eltypes) + iszero_vector(sparse_AT, eltypes) elseif sparse_AT <: AbstractSparseMatrix matrix(sparse_AT, eltypes) matrix_construction(sparse_AT, eltypes) broadcasting_matrix(sparse_AT, eltypes) mapreduce_matrix(sparse_AT, eltypes) linalg(sparse_AT, eltypes) + iszero_matrix(sparse_AT, eltypes) end end end @@ -361,3 +363,62 @@ function linalg(AT, eltypes) end end end + +function iszero_vector(AT, eltypes) + for ET in eltypes + @testset "iszero SparseVector($ET)" begin + m = 25 + + # Test non-zero sparse vector + x = sprand(ET, m, 0.5) + while iszero(x) + x = sprand(ET, m, 0.5) + end + d_x = AT(x) + @test iszero(d_x) == iszero(x) + @test iszero(d_x) == false + + # Test zero sparse vector (no stored elements) + z = spzeros(ET, m) + d_z = AT(z) + @test iszero(d_z) == iszero(z) + @test iszero(d_z) == true + + # Test sparse vector with stored zeros (e.g., after operations) + # Create a sparse vector then multiply by zero + x_zeros = x .* zero(ET) + d_x_zeros = d_x .* zero(ET) + @test iszero(d_x_zeros) == iszero(x_zeros) + @test iszero(d_x_zeros) == true + end + end +end + +function iszero_matrix(AT, eltypes) + for ET in eltypes + @testset "iszero SparseMatrix($ET)" begin + m, n = 10, 10 + + # Test non-zero sparse matrix + A = sprand(ET, m, n, 0.5) + while iszero(A) + A = sprand(ET, m, n, 0.5) + end + dA = AT(A) + @test iszero(dA) == iszero(A) + @test iszero(dA) == false + + # Test zero sparse matrix (no stored elements) + ZA = spzeros(ET, m, n) + dZA = AT(ZA) + @test iszero(dZA) == iszero(ZA) + @test iszero(dZA) == true + + # Test sparse matrix with stored zeros + A_zeros = A .* zero(ET) + dA_zeros = dA .* zero(ET) + @test iszero(dA_zeros) == iszero(A_zeros) + @test iszero(dA_zeros) == true + end + end +end