diff --git a/Project.toml b/Project.toml index 7327e3b..a5428d3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SerializedArrays" uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" authors = ["ITensor developers and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index b931ca9..6e46642 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -5,6 +5,10 @@ using ConstructionBase: constructorof using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock! using Serialization: deserialize, serialize +# +# AbstractSerializedArray +# + abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2} const AbstractSerializedVector{T} = AbstractSerializedArray{T,1} @@ -63,6 +67,10 @@ end # return convert(arrayt, copy(a)) # end +# +# SerializedArray +# + struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N} file::String axes::Axes @@ -100,10 +108,7 @@ Base.size(a::SerializedArray) = length.(axes(a)) to_axis(r::AbstractUnitRange) = r to_axis(d::Integer) = Base.OneTo(d) -# -# DiskArrays -# - +# DiskArrays interface DiskArrays.haschunks(::SerializedArray) = Unchunked() function DiskArrays.readblock!( a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} @@ -131,12 +136,18 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz return similar(a, output_size) end +# +# PermutedSerializedArray +# + struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <: AbstractSerializedArray{T,N} permuted_parent::P end Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent)) +file(a::PermutedSerializedArray) = file(parent(a)) + perm(a::PermutedSerializedArray) = perm(a.permuted_parent) perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p @@ -172,17 +183,21 @@ function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange # Permute the indices inew = genperm(i, ip) # Permute the dest block and read from the true parent - DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...) + readblock!(parent(a), PermutedDimsArray(aout, ip), inew...) return nothing end function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...) ip = iperm(a) inew = genperm(i, ip) # Permute the dest block and write from the true parent - DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...) + writeblock!(parent(a), PermutedDimsArray(v, ip), inew...) return nothing end +# +# ReshapedSerializedArray +# + struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N} parent::P axes::Axes @@ -190,6 +205,8 @@ end Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent) Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes) +file(a::ReshapedSerializedArray) = file(parent(a)) + function ReshapedSerializedArray( a::AbstractSerializedArray, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, @@ -256,6 +273,141 @@ function DiskArrays.writeblock!( return nothing end +# +# SubSerializedArray +# + +struct SubSerializedArray{T,N,P,I,L} <: AbstractSerializedArray{T,N} + sub_parent::SubArray{T,N,P,I,L} +end + +file(a::SubSerializedArray) = file(parent(a)) + +# Base methods +function Base.view(a::SerializedArray, i...) + return SubSerializedArray(SubArray(a, Base.to_indices(a, i))) +end +function Base.view(a::SerializedArray, i::CartesianIndices) + return SubSerializedArray(SubArray(a, Base.to_indices(a, i))) +end +Base.view(a::SubSerializedArray, i...) = SubSerializedArray(view(a.sub_parent, i...)) +Base.view(a::SubSerializedArray, i::CartesianIndices) = view(a, i.indices...) +Base.size(a::SubSerializedArray) = size(a.sub_parent) +Base.axes(a::SubSerializedArray) = axes(a.sub_parent) +Base.parent(a::SubSerializedArray) = parent(a.sub_parent) +Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent) + +function materialize(a::SubSerializedArray) + return view(copy(parent(a)), parentindices(a)...) +end +function Base.copy(a::SubSerializedArray) + return copy(materialize(a)) +end + +DiskArrays.haschunks(a::SubSerializedArray) = Unchunked() +function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...) + if i == axes(a) + aout .= copy(a) + end + aout[i...] = copy(view(a, i...)) + return nothing +end +function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...) + if i == axes(a) + serialize(file(a), ain) + return a + end + a_parent = copy(parent(a)) + pinds = parentindices(view(a.sub_parent, i...)) + a_parent[pinds...] = ain + serialize(file(a), a_parent) + return nothing +end + +# +# TransposeSerializedArray +# + +struct TransposeSerializedArray{T,P<:AbstractSerializedArray{T}} <: + AbstractSerializedMatrix{T} + parent::P +end +Base.parent(a::TransposeSerializedArray) = getfield(a, :parent) + +file(a::TransposeSerializedArray) = file(parent(a)) + +Base.axes(a::TransposeSerializedArray) = reverse(axes(parent(a))) +Base.size(a::TransposeSerializedArray) = length.(axes(a)) + +function Base.transpose(a::AbstractSerializedArray) + return TransposeSerializedArray(a) +end +Base.transpose(a::TransposeSerializedArray) = parent(a) + +function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) + return similar(parent(a), elt, dims) +end + +function materialize(a::TransposeSerializedArray) + return transpose(copy(parent(a))) +end +function Base.copy(a::TransposeSerializedArray) + return copy(materialize(a)) +end + +haschunks(a::TransposeSerializedArray) = Unchunked() +function DiskArrays.readblock!(a::TransposeSerializedArray, aout, i::OrdinalRange...) + readblock!(parent(a), transpose(aout), reverse(i)...) + return nothing +end +function DiskArrays.writeblock!(a::TransposeSerializedArray, ain, i::OrdinalRange...) + writeblock!(parent(a), transpose(aout), reverse(i)...) + return nothing +end + +# +# AdjointSerializedArray +# + +struct AdjointSerializedArray{T,P<:AbstractSerializedArray{T}} <: + AbstractSerializedMatrix{T} + parent::P +end +Base.parent(a::AdjointSerializedArray) = getfield(a, :parent) + +file(a::AdjointSerializedArray) = file(parent(a)) + +Base.axes(a::AdjointSerializedArray) = reverse(axes(parent(a))) +Base.size(a::AdjointSerializedArray) = length.(axes(a)) + +function Base.adjoint(a::AbstractSerializedArray) + return AdjointSerializedArray(a) +end +Base.adjoint(a::AdjointSerializedArray) = parent(a) +Base.adjoint(a::TransposeSerializedArray{<:Real}) = parent(a) +Base.transpose(a::AdjointSerializedArray{<:Real}) = parent(a) + +function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) + return similar(parent(a), elt, dims) +end + +function materialize(a::AdjointSerializedArray) + return adjoint(copy(parent(a))) +end +function Base.copy(a::AdjointSerializedArray) + return copy(materialize(a)) +end + +haschunks(a::AdjointSerializedArray) = Unchunked() +function DiskArrays.readblock!(a::AdjointSerializedArray, aout, i::OrdinalRange...) + readblock!(parent(a), adjoint(aout), reverse(i)...) + return nothing +end +function DiskArrays.writeblock!(a::AdjointSerializedArray, ain, i::OrdinalRange...) + writeblock!(parent(a), adjoint(aout), reverse(i)...) + return nothing +end + # # Broadcast # @@ -264,7 +416,9 @@ using Base.Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end -Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}() +function Base.BroadcastStyle(arrayt::Type{<:AbstractSerializedArray}) + SerializedArrayStyle{ndims(arrayt)}() +end function Base.BroadcastStyle( ::SerializedArrayStyle{N}, ::SerializedArrayStyle{M} ) where {N,M} diff --git a/test/test_basics.jl b/test/test_basics.jl index a4269be..ca6087a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,12 @@ using GPUArraysCore: @allowscalar using JLArrays: JLArray -using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray +using SerializedArrays: + AdjointSerializedArray, + PermutedSerializedArray, + ReshapedSerializedArray, + SerializedArray, + SubSerializedArray, + TransposeSerializedArray using StableRNGs: StableRNG using Test: @test, @testset using TestExtras: @constinferred @@ -50,6 +56,33 @@ arrayts = (Array, JLArray) @test a isa PermutedSerializedArray{elt,2} @test similar(a) isa arrayt{elt,2} @test copy(a) == permutedims(x, (2, 1)) + @test copy(2a) == 2permutedims(x, (2, 1)) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = transpose(SerializedArray(x)) + @test a isa TransposeSerializedArray{elt} + @test similar(a) isa arrayt{elt,2} + @test copy(a) == transpose(x) + @test copy(2a) == 2transpose(x) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = adjoint(SerializedArray(x)) + @test a isa AdjointSerializedArray{elt} + @test similar(a) isa arrayt{elt,2} + @test copy(a) == adjoint(x) + @test copy(2a) == 2adjoint(x) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + @test transpose(transpose(a)) === a + @test adjoint(adjoint(a)) === a + if isreal(a) + @test adjoint(transpose(a)) === a + @test transpose(adjoint(a)) === a + end rng = StableRNG(123) x = arrayt(randn(rng, elt, 4, 4)) @@ -96,4 +129,17 @@ arrayts = (Array, JLArray) copyto!(y, a) b = SerializedArray(y) @test b == a + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = @view x[2:3, 2:3] + a = SerializedArray(a) + b = @view a[2:3, 2:3] + @test b isa SubSerializedArray{elt,2} + c = 2b + @test 2y == copy(c) + @allowscalar begin + b[1, 1] = 2 + @test @constinferred(b[1, 1]) == 2 + end end