Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
name = "SerializedArrays"
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[extensions]
SerializedArraysAdaptExt = "Adapt"
SerializedArraysLinearAlgebraExt = "LinearAlgebra"

[compat]
Adapt = "4.3.0"
ConstructionBase = "1.5.8"
DiskArrays = "0.4.12"
LinearAlgebra = "1.10"
Expand Down
10 changes: 10 additions & 0 deletions ext/SerializedArraysAdaptExt/SerializedArraysAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module SerializedArraysAdaptExt

using Adapt: Adapt
using SerializedArrays: SerializedArray

function Adapt.adapt_storage(arrayt::Type{<:SerializedArray}, a::AbstractArray)
return convert(arrayt, a)
end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module SerializedArraysLinearAlgebraExt

using LinearAlgebra: LinearAlgebra, mul!
using SerializedArrays: AbstractSerializedMatrix

function LinearAlgebra.mul!(
a_dest::AbstractMatrix,
a1::AbstractSerializedMatrix,
a2::AbstractSerializedMatrix,
α::Number,
β::Number,
)
mul!(a_dest, copy(a1), copy(a2), α, β)
return a_dest
end

for f in [:eigen, :qr, :svd]
@eval begin
function LinearAlgebra.$f(a::AbstractSerializedMatrix; kwargs...)
return LinearAlgebra.$f(copy(a))
end
end
end

end
215 changes: 198 additions & 17 deletions src/SerializedArrays.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,69 @@
module SerializedArrays

using Base.PermutedDimsArrays: genperm
using ConstructionBase: constructorof
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
using LinearAlgebra: LinearAlgebra, mul!
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
using Serialization: deserialize, serialize

struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N}
abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}

function _copyto_write!(dst, src)
writeblock!(dst, src, axes(src)...)
return dst
end
function _copyto_read!(dst, src)
readblock!(src, dst, axes(src)...)
return dst
end

function Base.copyto!(dst::AbstractSerializedArray, src::AbstractArray)
return _copyto_write!(dst, src)
end
function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray)
return _copyto_read!(dst, src)
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray)
return copyto!(dst, copy(src))
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray)
return copyto!(dst, copy(src))

Check warning on line 33 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray)
return _copyto_write!(dst, src)

Check warning on line 37 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
end
# Fix ambiguity error.
function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray)
return _copyto_read!(dst, src)

Check warning on line 41 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end

function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray)
return copy(a1) == copy(a2)
end
function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray)
return a1 == copy(a2)
end
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
return copy(a1) == a2
end

# # These cause too many ambiguity errors, try bringing them back.
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
# return arrayt(a)
# end
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
# return convert(arrayt, copy(a))
# end
# # Fixes ambiguity error.
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
# return convert(arrayt, copy(a))
# end

struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N}
file::String
axes::Axes
end
Expand All @@ -22,17 +80,26 @@
return SerializedArray(tempname(), a)
end

function Base.convert(arrayt::Type{<:SerializedArray}, a::AbstractArray)
return arrayt(a)
end

function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return constructorof(arraytype(a)){elt}(undef, dims...)
end

function materialize(a::SerializedArray)
return deserialize(file(a))::arraytype(a)
end
function Base.copy(a::SerializedArray)
arrayt = arraytype(a)
return convert(arrayt, deserialize(file(a)))::arrayt
return materialize(a)
end

Base.size(a::SerializedArray) = length.(axes(a))

to_axis(r::AbstractUnitRange) = r

Check warning on line 100 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L100

Added line #L100 was not covered by tests
to_axis(d::Integer) = Base.OneTo(d)

#
# DiskArrays
#
Expand Down Expand Up @@ -64,6 +131,131 @@
return similar(a, output_size)
end

struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <:
AbstractSerializedArray{T,N}
permuted_parent::P
end
Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent))

perm(a::PermutedSerializedArray) = perm(a.permuted_parent)
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p

iperm(a::PermutedSerializedArray) = iperm(a.permuted_parent)
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip

Check warning on line 144 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L143-L144

Added lines #L143 - L144 were not covered by tests

Base.axes(a::PermutedSerializedArray) = genperm(axes(parent(a)), perm(a))
Base.size(a::PermutedSerializedArray) = length.(axes(a))

function PermutedSerializedArray(a::AbstractArray, perm)
a′ = PermutedDimsArray(a, perm)
return PermutedSerializedArray{eltype(a),ndims(a),typeof(a′)}(a′)
end

function Base.permutedims(a::AbstractSerializedArray, perm)
return PermutedSerializedArray(a, perm)
end

function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return similar(parent(a), elt, dims)
end

function materialize(a::PermutedSerializedArray)
return PermutedDimsArray(copy(parent(a)), perm(a))
end
function Base.copy(a::PermutedSerializedArray)
return copy(materialize(a))
end

haschunks(a::PermutedSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange...)
ip = iperm(a)

Check warning on line 171 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L169-L171

Added lines #L169 - L171 were not covered by tests
# Permute the indices
inew = genperm(i, ip)

Check warning on line 173 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L173

Added line #L173 was not covered by tests
# Permute the dest block and read from the true parent
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...)
return nothing

Check warning on line 176 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L175-L176

Added lines #L175 - L176 were not covered by tests
end
function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...)
ip = iperm(a)
inew = genperm(i, ip)

Check warning on line 180 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L178-L180

Added lines #L178 - L180 were not covered by tests
# Permute the dest block and write from the true parent
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...)
return nothing

Check warning on line 183 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
end

struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N}
parent::P
axes::Axes
end
Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent)
Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes)

function ReshapedSerializedArray(
a::AbstractSerializedArray,
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
return ReshapedSerializedArray{eltype(a),length(ax),typeof(a),typeof(ax)}(a, ax)
end
function ReshapedSerializedArray(
a::AbstractSerializedArray,
shape::Tuple{
Union{Integer,AbstractUnitRange{<:Integer}},
Vararg{Union{Integer,AbstractUnitRange{<:Integer}}},
},
)
return ReshapedSerializedArray(a, to_axis.(shape))
end

Base.size(a::ReshapedSerializedArray) = length.(axes(a))

Check warning on line 209 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L209

Added line #L209 was not covered by tests

function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
return similar(parent(a), elt, dims)
end

function materialize(a::ReshapedSerializedArray)
return reshape(materialize(parent(a)), axes(a))
end
function Base.copy(a::ReshapedSerializedArray)
a′ = materialize(a)
return a′ isa Base.ReshapedArray ? copy(a′) : a′
end

# Special case for handling nested wrappers that aren't
# friendly on GPU. Consider special cases of strded arrays
# and handle with stride manipulations.
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
a′ = reshape(copy(parent(a)), axes(a))
return a′ isa Base.ReshapedArray ? copy(a′) : a′
end

function Base.reshape(a::AbstractSerializedArray, dims::Tuple{Int,Vararg{Int}})
return ReshapedSerializedArray(a, dims)
end

DiskArrays.haschunks(a::ReshapedSerializedArray) = Unchunked()
function DiskArrays.readblock!(

Check warning on line 236 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L235-L236

Added lines #L235 - L236 were not covered by tests
a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
aout .= copy(a)
return a

Check warning on line 241 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L239-L241

Added lines #L239 - L241 were not covered by tests
end
aout .= @view copy(a)[i...]
return nothing

Check warning on line 244 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L243-L244

Added lines #L243 - L244 were not covered by tests
end
function DiskArrays.writeblock!(

Check warning on line 246 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L246

Added line #L246 was not covered by tests
a::ReshapedSerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
serialize(file(a), ain)
return a

Check warning on line 251 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L249-L251

Added lines #L249 - L251 were not covered by tests
end
a′ = copy(a)
a′[i...] = ain
serialize(file(a), a′)
return nothing

Check warning on line 256 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L253-L256

Added lines #L253 - L256 were not covered by tests
end

#
# Broadcast
#
Expand All @@ -86,7 +278,7 @@
end

struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <:
AbstractDiskArray{T,N}
AbstractSerializedArray{T,N}
broadcasted::BC
end
function BroadcastSerializedArray(
Expand All @@ -106,15 +298,4 @@
return BroadcastSerializedArray(flatten(broadcasted))
end

#
# LinearAlgebra
#

function LinearAlgebra.mul!(
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
)
mul!(a_dest, copy(a1), copy(a2), α, β)
return a_dest
end

end
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -10,9 +12,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Adapt = "4"
Aqua = "0.8"
GPUArraysCore = "0.2"
JLArrays = "0.2"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
SerializedArrays = "0.1"
StableRNGs = "1"
Expand Down
22 changes: 22 additions & 0 deletions test/test_adaptext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Adapt: adapt
using JLArrays: JLArray
using SerializedArrays: SerializedArray
using StableRNGs: StableRNG
using Test: @test, @testset
using TestExtras: @constinferred

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
arrayts = (Array, JLArray)
@testset "SerializedArraysAdaptExt (eltype=$elt, arraytype=$arrayt)" for elt in elts,
arrayt in arrayts

rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
y = PermutedDimsArray(x, (2, 1))
a = adapt(SerializedArray, x)
@test a isa SerializedArray{elt,2,arrayt{elt,2}}
b = adapt(SerializedArray, y)
@test b isa
PermutedDimsArray{elt,2,(2, 1),(2, 1),<:SerializedArray{elt,2,<:arrayt{elt,2}}}
@test parent(b) == a
end
Loading
Loading