Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SerializedArrays"
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.1.1"
version = "0.1.2"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
168 changes: 161 additions & 7 deletions src/SerializedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
using Serialization: deserialize, serialize

#
# AbstractSerializedArray
#

abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
Expand Down Expand Up @@ -63,6 +67,10 @@
# return convert(arrayt, copy(a))
# end

#
# SerializedArray
#

struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N}
file::String
axes::Axes
Expand Down Expand Up @@ -100,10 +108,7 @@
to_axis(r::AbstractUnitRange) = r
to_axis(d::Integer) = Base.OneTo(d)

#
# DiskArrays
#

# DiskArrays interface
DiskArrays.haschunks(::SerializedArray) = Unchunked()
function DiskArrays.readblock!(
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
Expand Down Expand Up @@ -131,12 +136,18 @@
return similar(a, output_size)
end

#
# PermutedSerializedArray
#

struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
AbstractSerializedArray{T,N}
permuted_parent::P
end
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))

file(a::PermutedSerializedArray) = file(parent(a))

Check warning on line 149 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L149

Added line #L149 was not covered by tests

perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p

Expand Down Expand Up @@ -172,24 +183,30 @@
# Permute the indices
inew = genperm(i, ip)
# Permute the dest block and read from the true parent
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)

Check warning on line 186 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L186

Added line #L186 was not covered by tests
return nothing
end
function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...)
ip = iperm(a)
inew = genperm(i, ip)
# Permute the dest block and write from the true parent
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)

Check warning on line 193 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L193

Added line #L193 was not covered by tests
return nothing
end

#
# ReshapedSerializedArray
#

struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N}
parent::P
axes::Axes
end
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)

file(a::ReshapedSerializedArray) = file(parent(a))

Check warning on line 208 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L208

Added line #L208 was not covered by tests

function ReshapedSerializedArray(
a::AbstractSerializedArray,
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
Expand Down Expand Up @@ -256,6 +273,141 @@
return nothing
end

#
# SubSerializedArray
#

struct SubSerializedArray{T,N,P,I,L} <: AbstractSerializedArray{T,N}
sub_parent::SubArray{T,N,P,I,L}
end

file(a::SubSerializedArray) = file(parent(a))

# Base methods
function Base.view(a::SerializedArray, i...)
return SubSerializedArray(SubArray(a, Base.to_indices(a, i)))
end
function Base.view(a::SerializedArray, i::CartesianIndices)
return SubSerializedArray(SubArray(a, Base.to_indices(a, i)))

Check warning on line 291 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L290-L291

Added lines #L290 - L291 were not covered by tests
end
Base.view(a::SubSerializedArray, i...) = SubSerializedArray(view(a.sub_parent, i...))
Base.view(a::SubSerializedArray, i::CartesianIndices) = view(a, i.indices...)

Check warning on line 294 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L294

Added line #L294 was not covered by tests
Base.size(a::SubSerializedArray) = size(a.sub_parent)
Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
Base.parent(a::SubSerializedArray) = parent(a.sub_parent)
Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent)

function materialize(a::SubSerializedArray)
return view(copy(parent(a)), parentindices(a)...)
end
function Base.copy(a::SubSerializedArray)
return copy(materialize(a))
end

DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
if i == axes(a)
aout .= copy(a)

Check warning on line 310 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L310

Added line #L310 was not covered by tests
end
aout[i...] = copy(view(a, i...))
return nothing
end
function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
if i == axes(a)
serialize(file(a), ain)
return a

Check warning on line 318 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L317-L318

Added lines #L317 - L318 were not covered by tests
end
a_parent = copy(parent(a))
pinds = parentindices(view(a.sub_parent, i...))
a_parent[pinds...] = ain
serialize(file(a), a_parent)
return nothing
end

#
# TransposeSerializedArray
#

struct TransposeSerializedArray{T,P<:AbstractSerializedArray{T}} <:
AbstractSerializedMatrix{T}
parent::P
end
Base.parent(a::TransposeSerializedArray) = getfield(a, :parent)

file(a::TransposeSerializedArray) = file(parent(a))

Check warning on line 337 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L337

Added line #L337 was not covered by tests

Base.axes(a::TransposeSerializedArray) = reverse(axes(parent(a)))
Base.size(a::TransposeSerializedArray) = length.(axes(a))

Check warning on line 340 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L340

Added line #L340 was not covered by tests

function Base.transpose(a::AbstractSerializedArray)
return TransposeSerializedArray(a)
end
Base.transpose(a::TransposeSerializedArray) = parent(a)

function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return similar(parent(a), elt, dims)
end

function materialize(a::TransposeSerializedArray)
return transpose(copy(parent(a)))
end
function Base.copy(a::TransposeSerializedArray)
return copy(materialize(a))
end

haschunks(a::TransposeSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::TransposeSerializedArray, aout, i::OrdinalRange...)
readblock!(parent(a), transpose(aout), reverse(i)...)
return nothing

Check warning on line 361 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L358-L361

Added lines #L358 - L361 were not covered by tests
end
function DiskArrays.writeblock!(a::TransposeSerializedArray, ain, i::OrdinalRange...)
writeblock!(parent(a), transpose(aout), reverse(i)...)
return nothing

Check warning on line 365 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L363-L365

Added lines #L363 - L365 were not covered by tests
end

#
# AdjointSerializedArray
#

struct AdjointSerializedArray{T,P<:AbstractSerializedArray{T}} <:
AbstractSerializedMatrix{T}
parent::P
end
Base.parent(a::AdjointSerializedArray) = getfield(a, :parent)

file(a::AdjointSerializedArray) = file(parent(a))

Check warning on line 378 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L378

Added line #L378 was not covered by tests

Base.axes(a::AdjointSerializedArray) = reverse(axes(parent(a)))
Base.size(a::AdjointSerializedArray) = length.(axes(a))

Check warning on line 381 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L381

Added line #L381 was not covered by tests

function Base.adjoint(a::AbstractSerializedArray)
return AdjointSerializedArray(a)
end
Base.adjoint(a::AdjointSerializedArray) = parent(a)
Base.adjoint(a::TransposeSerializedArray{<:Real}) = parent(a)
Base.transpose(a::AdjointSerializedArray{<:Real}) = parent(a)

function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return similar(parent(a), elt, dims)
end

function materialize(a::AdjointSerializedArray)
return adjoint(copy(parent(a)))
end
function Base.copy(a::AdjointSerializedArray)
return copy(materialize(a))
end

haschunks(a::AdjointSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::AdjointSerializedArray, aout, i::OrdinalRange...)
readblock!(parent(a), adjoint(aout), reverse(i)...)
return nothing

Check warning on line 404 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L401-L404

Added lines #L401 - L404 were not covered by tests
end
function DiskArrays.writeblock!(a::AdjointSerializedArray, ain, i::OrdinalRange...)
writeblock!(parent(a), adjoint(aout), reverse(i)...)
return nothing

Check warning on line 408 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L406-L408

Added lines #L406 - L408 were not covered by tests
end

#
# Broadcast
#
Expand All @@ -264,7 +416,9 @@
BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten

struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}()
function Base.BroadcastStyle(arrayt::Type{<:AbstractSerializedArray})
SerializedArrayStyle{ndims(arrayt)}()
end
function Base.BroadcastStyle(
::SerializedArrayStyle{N}, ::SerializedArrayStyle{M}
) where {N,M}
Expand Down
48 changes: 47 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using GPUArraysCore: @allowscalar
using JLArrays: JLArray
using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray
using SerializedArrays:
AdjointSerializedArray,
PermutedSerializedArray,
ReshapedSerializedArray,
SerializedArray,
SubSerializedArray,
TransposeSerializedArray
using StableRNGs: StableRNG
using Test: @test, @testset
using TestExtras: @constinferred
Expand Down Expand Up @@ -50,6 +56,33 @@ arrayts = (Array, JLArray)
@test a isa PermutedSerializedArray{elt,2}
@test similar(a) isa arrayt{elt,2}
@test copy(a) == permutedims(x, (2, 1))
@test copy(2a) == 2permutedims(x, (2, 1))

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
a = transpose(SerializedArray(x))
@test a isa TransposeSerializedArray{elt}
@test similar(a) isa arrayt{elt,2}
@test copy(a) == transpose(x)
@test copy(2a) == 2transpose(x)

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
a = adjoint(SerializedArray(x))
@test a isa AdjointSerializedArray{elt}
@test similar(a) isa arrayt{elt,2}
@test copy(a) == adjoint(x)
@test copy(2a) == 2adjoint(x)

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
a = SerializedArray(x)
@test transpose(transpose(a)) === a
@test adjoint(adjoint(a)) === a
if isreal(a)
@test adjoint(transpose(a)) === a
@test transpose(adjoint(a)) === a
end

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
Expand Down Expand Up @@ -96,4 +129,17 @@ arrayts = (Array, JLArray)
copyto!(y, a)
b = SerializedArray(y)
@test b == a

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
y = @view x[2:3, 2:3]
a = SerializedArray(a)
b = @view a[2:3, 2:3]
@test b isa SubSerializedArray{elt,2}
c = 2b
@test 2y == copy(c)
@allowscalar begin
b[1, 1] = 2
@test @constinferred(b[1, 1]) == 2
end
end
Loading