diff --git a/Project.toml b/Project.toml index 579daf6..fd11f26 100644 --- a/Project.toml +++ b/Project.toml @@ -3,5 +3,15 @@ uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" authors = ["ITensor developers and contributors"] version = "0.1.0" +[deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + [compat] +ConstructionBase = "1.5.8" +DiskArrays = "0.4.12" +LinearAlgebra = "1.10" +Serialization = "1.10" julia = "1.10" diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index ea728be..b888e25 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -1,5 +1,120 @@ module SerializedArrays -# Write your package code here. +using ConstructionBase: constructorof +using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked +using LinearAlgebra: LinearAlgebra, mul! +using Serialization: deserialize, serialize + +struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N} + file::String + axes::Axes +end +file(a::SerializedArray) = getfield(a, :file) +Base.axes(a::SerializedArray) = getfield(a, :axes) +arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A + +function SerializedArray(file::String, a::AbstractArray) + serialize(file, a) + ax = axes(a) + return SerializedArray{eltype(a),ndims(a),typeof(a),typeof(ax)}(file, ax) +end +function SerializedArray(a::AbstractArray) + return SerializedArray(tempname(), a) +end + +function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) + return constructorof(arraytype(a)){elt}(undef, dims...) +end + +function Base.copy(a::SerializedArray) + arrayt = arraytype(a) + return convert(arrayt, deserialize(file(a)))::arrayt +end + +Base.size(a::SerializedArray) = length.(axes(a)) + +# +# DiskArrays +# + +DiskArrays.haschunks(::SerializedArray) = Unchunked() +function DiskArrays.readblock!( + a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} +) where {N} + if i == axes(a) + aout .= copy(a) + return a + end + aout .= @view copy(a)[i...] + return a +end +function DiskArrays.writeblock!( + a::SerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N} +) where {N} + if i == axes(a) + serialize(file(a), ain) + return a + end + a′ = copy(a) + a′[i...] = ain + serialize(file(a), a′) + return a +end +function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_size::Tuple) + return similar(a, output_size) +end + +# +# Broadcast +# + +using Base.Broadcast: + BroadcastStyle, Broadcasted, DefaultArrayStyle, combine_styles, flatten + +struct SerializedArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end +Base.BroadcastStyle(arrayt::Type{<:SerializedArray}) = SerializedArrayStyle{ndims(arrayt)}() +function Base.BroadcastStyle( + ::SerializedArrayStyle{N}, ::SerializedArrayStyle{M} +) where {N,M} + SerializedArrayStyle{max(N, M)}() +end +function Base.BroadcastStyle(::SerializedArrayStyle{N}, ::DefaultArrayStyle{M}) where {N,M} + return SerializedArrayStyle{max(N, M)}() +end +function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N}) where {N,M} + return SerializedArrayStyle{max(N, M)}() +end + +struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <: + AbstractDiskArray{T,N} + broadcasted::BC +end +function BroadcastSerializedArray( + broadcasted::B +) where {B<:Broadcasted{<:SerializedArrayStyle{N}}} where {N} + ElType = Base.Broadcast.combine_eltypes(broadcasted.f, broadcasted.args) + return BroadcastSerializedArray{ElType,N,B}(broadcasted) +end +Base.size(a::BroadcastSerializedArray) = size(a.broadcasted) +Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted +function Base.copy(a::BroadcastSerializedArray) + # Broadcast over the materialized arrays. + return copy(Base.Broadcast.broadcasted(a.broadcasted.f, copy.(a.broadcasted.args)...)) +end + +function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N} + return BroadcastSerializedArray(flatten(broadcasted)) +end + +# +# LinearAlgebra +# + +function LinearAlgebra.mul!( + a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number +) + mul!(a_dest, copy(a1), copy(a2), α, β) + return a_dest +end end diff --git a/test/Project.toml b/test/Project.toml index 3309f69..3064946 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,13 +1,21 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] -Aqua = "0.8.9" +Aqua = "0.8" +GPUArraysCore = "0.2" +JLArrays = "0.2" SafeTestsets = "0.1" SerializedArrays = "0.1" +StableRNGs = "1" Suppressor = "0.2" Test = "1.10" +TestExtras = "0.3" diff --git a/test/test_basics.jl b/test/test_basics.jl index 808728b..5bc6690 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,39 @@ -using SerializedArrays: SerializedArrays +using GPUArraysCore: @allowscalar +using JLArrays: JLArray +using SerializedArrays: SerializedArray +using StableRNGs: StableRNG using Test: @test, @testset +using TestExtras: @constinferred -@testset "SerializedArrays" begin - # Tests go here. +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) +arrayts = (Array, JLArray) +@testset "SerializedArrays (eltype=$elt, arraytype=$arrayt)" for elt in elts, + arrayt in arrayts + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + @test @constinferred(copy(a)) == x + @test typeof(copy(a)) == typeof(x) + + x = arrayt(zeros(elt, 4, 4)) + a = SerializedArray(x) + @allowscalar begin + a[1, 1] = 2 + @test @constinferred(a[1, 1]) == 2 + end + + x = arrayt(zeros(elt, 4, 4)) + a = SerializedArray(x) + b = 2a + @test @constinferred(copy(b)) == 2x + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + b = SerializedArray(y) + c = @constinferred(a * b) + @test c == x * y + @test c isa arrayt{elt,2} end