From f7c14ed8fb325fa39292a95ae8fbd654f6660846 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 16 Apr 2025 16:49:19 -0400 Subject: [PATCH 1/3] Distinguish deepmemory from memory --- Project.toml | 2 +- src/SerializedArrays.jl | 95 ++++++++++++++++++++++------------------- test/test_basics.jl | 2 +- 3 files changed, 52 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index bf08400..628893f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SerializedArrays" uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" authors = ["ITensor developers and contributors"] -version = "0.1.3" +version = "0.2.0" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index 13e1808..7b926ac 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -1,11 +1,25 @@ module SerializedArrays +export SerializedArray, disk, memory + using Base.PermutedDimsArrays: genperm using ConstructionBase: constructorof using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock! using Serialization: deserialize, serialize -memory(a) = a +adapt_serialized(to, x) = adapt_structure_serialized(to, x) +adapt_serialized(to) = Base.Fix1(adapt_structure_serialized, to) +adapt_structure_serialized(to, x) = adapt_storage_serialized(to, x) +adapt_storage_serialized(to, x) = x + +struct DeepMemoryAdaptor end +deepmemory(x) = adapt_serialized(DeepMemoryAdaptor(), x) + +struct MemoryAdaptor end +memory(x) = adapt_serialized(MemoryAdaptor(), x) +function adapt_storage_serialized(::MemoryAdaptor, x) + return adapt_serialized(DeepMemoryAdaptor(), x) +end # # AbstractSerializedArray @@ -15,9 +29,12 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2} const AbstractSerializedVector{T} = AbstractSerializedArray{T,1} -memory(a::AbstractSerializedArray) = copy(a) disk(a::AbstractSerializedArray) = a +function Base.copy(a::AbstractSerializedArray) + return copy(deepmemory(a)) +end + function _copyto_write!(dst, src) writeblock!(dst, src, axes(src)...) return dst @@ -62,18 +79,6 @@ function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray) return equals_serialized(a1, a2) end -# # These cause too many ambiguity errors, try bringing them back. -# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray) -# return arrayt(a) -# end -# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray) -# return convert(arrayt, memory(a)) -# end -# # Fixes ambiguity error. -# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray) -# return convert(arrayt, memory(a)) -# end - # # SerializedArray # @@ -105,11 +110,11 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) return constructorof(arraytype(a)){elt}(undef, dims...) end -function materialize(a::SerializedArray) +function adapt_structure_serialized(::DeepMemoryAdaptor, a::SerializedArray) return deserialize(file(a))::arraytype(a) end function Base.copy(a::SerializedArray) - return materialize(a) + return deepmemory(a) end Base.size(a::SerializedArray) = length.(axes(a)) @@ -123,7 +128,7 @@ function DiskArrays.readblock!( a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} ) where {N} if i == axes(a) - aout .= memory(a) + aout .= deepmemory(a) return a end aout .= @view memory(a)[i...] @@ -179,11 +184,13 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{ return similar(parent(a), elt, dims) end -function materialize(a::PermutedSerializedArray) - return PermutedDimsArray(memory(parent(a)), perm(a)) +function adapt_structure_serialized(to, a::PermutedSerializedArray) + return PermutedDimsArray(adapt_serialized(to, parent(a)), perm(a)) end -function Base.copy(a::PermutedSerializedArray) - return copy(materialize(a)) + +# Special case to eagerly instantiate permutations. +function adapt_structure_serialized(to::MemoryAdaptor, a::PermutedSerializedArray) + return copy(a) end haschunks(a::PermutedSerializedArray) = Unchunked() @@ -238,11 +245,11 @@ function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{ return similar(parent(a), elt, dims) end -function materialize(a::ReshapedSerializedArray) - return reshape(materialize(parent(a)), axes(a)) +function adapt_structure_serialized(to, a::ReshapedSerializedArray) + return reshape(adapt_serialized(to, parent(a)), axes(a)) end function Base.copy(a::ReshapedSerializedArray) - a′ = materialize(a) + a′ = deepmemory(a) return a′ isa Base.ReshapedArray ? copy(a′) : a′ end @@ -250,7 +257,7 @@ end # friendly on GPU. Consider special cases of strded arrays # and handle with stride manipulations. function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray}) - a′ = reshape(memory(parent(a)), axes(a)) + a′ = memory(a) return a′ isa Base.ReshapedArray ? copy(a′) : a′ end @@ -306,17 +313,14 @@ Base.axes(a::SubSerializedArray) = axes(a.sub_parent) Base.parent(a::SubSerializedArray) = parent(a.sub_parent) Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent) -function materialize(a::SubSerializedArray) - return view(copy(parent(a)), parentindices(a)...) -end -function Base.copy(a::SubSerializedArray) - return copy(materialize(a)) +function adapt_structure_serialized(to, a::SubSerializedArray) + return view(adapt_serialized(to, parent(a)), parentindices(a)...) end DiskArrays.haschunks(a::SubSerializedArray) = Unchunked() function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...) if i == axes(a) - aout .= memory(a) + aout .= deepmemory(a) end aout[i...] = memory(view(a, i...)) return nothing @@ -326,7 +330,7 @@ function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...) serialize(file(a), ain) return a end - a_parent = memory(parent(a)) + a_parent = deepmemory(parent(a)) pinds = parentindices(view(a.sub_parent, i...)) a_parent[pinds...] = ain serialize(file(a), a_parent) @@ -357,11 +361,8 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg return similar(parent(a), elt, dims) end -function materialize(a::TransposeSerializedArray) - return transpose(memory(parent(a))) -end -function Base.copy(a::TransposeSerializedArray) - return copy(materialize(a)) +function adapt_structure_serialized(to, a::TransposeSerializedArray) + return transpose(adapt_serialized(to, parent(a))) end haschunks(a::TransposeSerializedArray) = Unchunked() @@ -400,11 +401,8 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I return similar(parent(a), elt, dims) end -function materialize(a::AdjointSerializedArray) - return adjoint(memory(parent(a))) -end -function Base.copy(a::AdjointSerializedArray) - return copy(materialize(a)) +function adapt_structure_serialized(to, a::AdjointSerializedArray) + return adjoint(adapt_serialized(to, parent(a))) end haschunks(a::AdjointSerializedArray) = Unchunked() @@ -452,9 +450,16 @@ function BroadcastSerializedArray( end Base.size(a::BroadcastSerializedArray) = size(a.broadcasted) Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted -function Base.copy(a::BroadcastSerializedArray) - # Broadcast over the materialized arrays. - return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...)) + +function adapt_structure_serialized(to, a::BroadcastSerializedArray) + return Base.Broadcast.broadcasted( + a.broadcasted.f, map(adapt_serialized(to), a.broadcasted.args)... + ) +end + +# Special case to eagerly instantiate broadcasts. +function adapt_storage_serialized(::MemoryAdaptor, a::BroadcastSerializedArray) + return copy(a) end function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N} diff --git a/test/test_basics.jl b/test/test_basics.jl index b686735..7c667be 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -141,7 +141,7 @@ arrayts = (Array, JLArray) rng = StableRNG(123) x = arrayt(randn(rng, elt, 4, 4)) y = @view x[2:3, 2:3] - a = SerializedArray(a) + a = SerializedArray(x) b = @view a[2:3, 2:3] @test b isa SubSerializedArray{elt,2} c = 2b From 917b4e64d10877d6ea3bf803ee8897514965f405 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 16 Apr 2025 16:57:03 -0400 Subject: [PATCH 2/3] Tweaks --- src/SerializedArrays.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index 7b926ac..97a177c 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -32,7 +32,7 @@ const AbstractSerializedVector{T} = AbstractSerializedArray{T,1} disk(a::AbstractSerializedArray) = a function Base.copy(a::AbstractSerializedArray) - return copy(deepmemory(a)) + return copy(memory(a)) end function _copyto_write!(dst, src) @@ -114,7 +114,7 @@ function adapt_structure_serialized(::DeepMemoryAdaptor, a::SerializedArray) return deserialize(file(a))::arraytype(a) end function Base.copy(a::SerializedArray) - return deepmemory(a) + return memory(a) end Base.size(a::SerializedArray) = length.(axes(a)) @@ -190,7 +190,7 @@ end # Special case to eagerly instantiate permutations. function adapt_structure_serialized(to::MemoryAdaptor, a::PermutedSerializedArray) - return copy(a) + return copy(deepmemory(a)) end haschunks(a::PermutedSerializedArray) = Unchunked() @@ -249,14 +249,9 @@ function adapt_structure_serialized(to, a::ReshapedSerializedArray) return reshape(adapt_serialized(to, parent(a)), axes(a)) end function Base.copy(a::ReshapedSerializedArray) - a′ = deepmemory(a) - return a′ isa Base.ReshapedArray ? copy(a′) : a′ -end - -# Special case for handling nested wrappers that aren't -# friendly on GPU. Consider special cases of strded arrays -# and handle with stride manipulations. -function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray}) + # `memory` instantiates `PermutedSerializedArray`, which is + # friendlier for GPU. Consider special cases of strded arrays + # and handle with stride manipulations. a′ = memory(a) return a′ isa Base.ReshapedArray ? copy(a′) : a′ end From 71e4cb0ddc1e8851796b31f2251033527fc270cf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 16 Apr 2025 17:06:29 -0400 Subject: [PATCH 3/3] Simplification of adapt logic --- docs/Project.toml | 2 +- examples/Project.toml | 2 +- src/SerializedArrays.jl | 13 +++++++++---- test/Project.toml | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 25db743..6b6f724 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" [compat] Documenter = "1" Literate = "2" -SerializedArrays = "0.1" +SerializedArrays = "0.2" diff --git a/examples/Project.toml b/examples/Project.toml index e611acc..5a186da 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" [compat] -SerializedArrays = "0.1" +SerializedArrays = "0.2" diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index 97a177c..dbed708 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -17,9 +17,6 @@ deepmemory(x) = adapt_serialized(DeepMemoryAdaptor(), x) struct MemoryAdaptor end memory(x) = adapt_serialized(MemoryAdaptor(), x) -function adapt_storage_serialized(::MemoryAdaptor, x) - return adapt_serialized(DeepMemoryAdaptor(), x) -end # # AbstractSerializedArray @@ -110,9 +107,17 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) return constructorof(arraytype(a)){elt}(undef, dims...) end -function adapt_structure_serialized(::DeepMemoryAdaptor, a::SerializedArray) +function _memory(a::SerializedArray) return deserialize(file(a))::arraytype(a) end + +function adapt_storage_serialized(::DeepMemoryAdaptor, a::SerializedArray) + return _memory(a) +end +function adapt_storage_serialized(::MemoryAdaptor, a::SerializedArray) + return _memory(a) +end + function Base.copy(a::SerializedArray) return memory(a) end diff --git a/test/Project.toml b/test/Project.toml index d9107f5..2efa00f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,7 +18,7 @@ GPUArraysCore = "0.2" JLArrays = "0.2" LinearAlgebra = "1.10" SafeTestsets = "0.1" -SerializedArrays = "0.1" +SerializedArrays = "0.2" StableRNGs = "1" Suppressor = "0.2" Test = "1.10"