diff --git a/src/array_partition.jl b/src/array_partition.jl index 8f9ce1f1..a90ba108 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -548,7 +548,11 @@ end ## Linear Algebra function ArrayInterface.zeromatrix(A::ArrayPartition) - x = reduce(vcat, vec.(A.x)) + # Use foldl with explicit init to preserve array type (important for GPU arrays) + # Starting with vec of first element ensures the result type matches the input + vecs = vec.(A.x) + rest = Base.tail(vecs) + x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1]) return x .* x' .* false end diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index d80fe090..ea11379d 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -174,7 +174,10 @@ end #Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq function ArrayInterface.zeromatrix(A::NamedArrayPartition) B = ArrayPartition(A) - x = reduce(vcat, vec.(B.x)) + # Use foldl with explicit init to preserve array type (important for GPU arrays) + vecs = vec.(B.x) + rest = Base.tail(vecs) + x = isempty(rest) ? vecs[1] : foldl(vcat, rest; init = vecs[1]) return x .* x' .* false end