diff --git a/src/array_partition.jl b/src/array_partition.jl index 953011b9..857cf23c 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -31,6 +31,15 @@ end @inline ArrayPartition(f::F, N) where {F <: Function} = ArrayPartition(ntuple(f, Val(N))) ArrayPartition(x...) = ArrayPartition((x...,)) +function (::Type{ArrayPartition{T, S}})(::UndefInitializer, n::Integer) where {T, S <: Tuple} + if length(S.parameters) != 1 + throw(ArgumentError("ArrayPartition{T,S}(undef, n) is only supported for a single partition")) + end + part_type = S.parameters[1] + part = part_type(undef, n) + return ArrayPartition{T, S}((part,)) +end + function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S <: Tuple, copy_x} T = promote_type(map(recursive_bottom_eltype, x)...) if copy_x diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 2e056e7c..cf0d8c92 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -3,6 +3,15 @@ using RecursiveArrayTools, Test, Statistics, ArrayInterface, Adapt @test length(ArrayPartition()) == 0 @test isempty(ArrayPartition()) +# Test undef initializer for single-partition ArrayPartition +p_undef = ArrayPartition{Float64, Tuple{Vector{Float64}}}(undef, 10) +@test p_undef isa ArrayPartition{Float64, Tuple{Vector{Float64}}} +@test length(p_undef) == 10 +@test length(p_undef.x) == 1 +@test length(p_undef.x[1]) == 10 +# Test that multi-partition throws error +@test_throws ArgumentError ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}(undef, 10) + A = (rand(5), rand(5)) p = ArrayPartition(A) @inferred p[1]