Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,15 @@ uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.1.0"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
ConstructionBase = "1.5.8"
DiskArrays = "0.4.12"
LinearAlgebra = "1.10"
Serialization = "1.10"
julia = "1.10"
117 changes: 116 additions & 1 deletion src/SerializedArrays.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,120 @@
module SerializedArrays

# Write your package code here.
using ConstructionBase: constructorof
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
using LinearAlgebra: LinearAlgebra, mul!
using Serialization: deserialize, serialize

struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N}
file::String
axes::Axes
end
file(a::SerializedArray) = getfield(a, :file)
Base.axes(a::SerializedArray) = getfield(a, :axes)
arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A

function SerializedArray(file::String, a::AbstractArray)
serialize(file, a)
ax = axes(a)
return SerializedArray{eltype(a),ndims(a),typeof(a),typeof(ax)}(file, ax)
end
function SerializedArray(a::AbstractArray)
return SerializedArray(tempname(), a)
end

function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return constructorof(arraytype(a)){elt}(undef, dims...)
end

function Base.copy(a::SerializedArray)
arrayt = arraytype(a)
return convert(arrayt, deserialize(file(a)))::arrayt
end

Base.size(a::SerializedArray) = length.(axes(a))

#
# DiskArrays
#

DiskArrays.haschunks(::SerializedArray) = Unchunked()
function DiskArrays.readblock!(
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
aout .= copy(a)
return a

Check warning on line 46 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L45-L46

Added lines #L45 - L46 were not covered by tests
end
aout .= @view copy(a)[i...]
return a
end
function DiskArrays.writeblock!(
a::SerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
serialize(file(a), ain)
return a

Check warning on line 56 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
end
a′ = copy(a)
a′[i...] = ain
serialize(file(a), a′)
return a
end
function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_size::Tuple)
return similar(a, output_size)

Check warning on line 64 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end

#
# Broadcast
#

using Base.Broadcast:
BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten

struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}()
function Base.BroadcastStyle(

Check warning on line 76 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L76

Added line #L76 was not covered by tests
::SerializedArrayStyle{N}, ::SerializedArrayStyle{M}
) where {N,M}
SerializedArrayStyle{max(N, M)}()

Check warning on line 79 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L79

Added line #L79 was not covered by tests
end
function Base.BroadcastStyle(::SerializedArrayStyle{N}, ::DefaultArrayStyle{M}) where {N,M}
return SerializedArrayStyle{max(N, M)}()
end
function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N}) where {N,M}
return SerializedArrayStyle{max(N, M)}()
end

struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <:
AbstractDiskArray{T,N}
broadcasted::BC
end
function BroadcastSerializedArray(
broadcasted::B
) where {B<:Broadcasted{<:SerializedArrayStyle{N}}} where {N}
ElType = Base.Broadcast.combine_eltypes(broadcasted.f, broadcasted.args)
return BroadcastSerializedArray{ElType,N,B}(broadcasted)
end
Base.size(a::BroadcastSerializedArray) = size(a.broadcasted)
Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted

Check warning on line 99 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
function Base.copy(a::BroadcastSerializedArray)
# Broadcast over the materialized arrays.
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, copy.(a.broadcasted.args)...))
end

function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
return BroadcastSerializedArray(flatten(broadcasted))
end

#
# LinearAlgebra
#

function LinearAlgebra.mul!(
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
)
mul!(a_dest, copy(a1), copy(a2), α, β)
return a_dest
end

end
10 changes: 9 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8.9"
Aqua = "0.8"
GPUArraysCore = "0.2"
JLArrays = "0.2"
SafeTestsets = "0.1"
SerializedArrays = "0.1"
StableRNGs = "1"
Suppressor = "0.2"
Test = "1.10"
TestExtras = "0.3"
39 changes: 36 additions & 3 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
using SerializedArrays: SerializedArrays
using GPUArraysCore: @allowscalar
using JLArrays: JLArray
using SerializedArrays: SerializedArray
using StableRNGs: StableRNG
using Test: @test, @testset
using TestExtras: @constinferred

@testset "SerializedArrays" begin
# Tests go here.
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
arrayts = (Array, JLArray)
@testset "SerializedArrays (eltype=$elt, arraytype=$arrayt)" for elt in elts,
arrayt in arrayts

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
a = SerializedArray(x)
@test @constinferred(copy(a)) == x
@test typeof(copy(a)) == typeof(x)

x = arrayt(zeros(elt, 4, 4))
a = SerializedArray(x)
@allowscalar begin
a[1, 1] = 2
@test @constinferred(a[1, 1]) == 2
end

x = arrayt(zeros(elt, 4, 4))
a = SerializedArray(x)
b = 2a
@test @constinferred(copy(b)) == 2x

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
y = arrayt(randn(rng, elt, 4, 4))
a = SerializedArray(x)
b = SerializedArray(y)
c = @constinferred(a * b)
@test c == x * y
@test c isa arrayt{elt,2}
end
Loading