From fbfcc7f8d36c56be8215a399bfa93ec92b6dd2e3 Mon Sep 17 00:00:00 2001 From: Claude Code Date: Wed, 7 Jan 2026 18:12:58 -0500 Subject: [PATCH] Fix zeromatrix to preserve GPU array types for ArrayPartition The previous implementation used `reduce(vcat, vec.(A.x))` which could cause type conversion issues with GPU arrays, leading to scalar indexing errors when using implicit ODE solvers with ArrayPartition of CuArrays. The fix uses `foldl` with an explicit `init` value from the first element of the tuple, ensuring the result array type matches the input type. This preserves GPU array types (CuArray, MtlArray, etc.) when building the zero matrix. Fixes #496 Co-Authored-By: Claude Opus 4.5 --- src/array_partition.jl | 6 +++++- src/named_array_partition.jl | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 8f9ce1f1..a90ba108 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -548,7 +548,11 @@ end ## Linear Algebra function ArrayInterface.zeromatrix(A::ArrayPartition) - x = reduce(vcat, vec.(A.x)) + # Use foldl with explicit init to preserve array type (important for GPU arrays) + # Starting with vec of first element ensures the result type matches the input + vecs = vec.(A.x) + rest = Base.tail(vecs) + x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1]) return x .* x' .* false end diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index d80fe090..ea11379d 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -174,7 +174,10 @@ end #Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq function ArrayInterface.zeromatrix(A::NamedArrayPartition) B = ArrayPartition(A) - x = reduce(vcat, vec.(B.x)) + # Use foldl with explicit init to preserve array type (important for GPU arrays) + vecs = vec.(B.x) + rest = Base.tail(vecs) + x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1]) return x .* x' .* false end