diff --git a/Project.toml b/Project.toml index fd11f26..7327e3b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,23 @@ name = "SerializedArrays" uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" authors = ["ITensor developers and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[extensions] +SerializedArraysAdaptExt = "Adapt" +SerializedArraysLinearAlgebraExt = "LinearAlgebra" + [compat] +Adapt = "4.3.0" ConstructionBase = "1.5.8" DiskArrays = "0.4.12" LinearAlgebra = "1.10" diff --git a/ext/SerializedArraysAdaptExt/SerializedArraysAdaptExt.jl b/ext/SerializedArraysAdaptExt/SerializedArraysAdaptExt.jl new file mode 100644 index 0000000..8dbc07b --- /dev/null +++ b/ext/SerializedArraysAdaptExt/SerializedArraysAdaptExt.jl @@ -0,0 +1,10 @@ +module SerializedArraysAdaptExt + +using Adapt: Adapt +using SerializedArrays: SerializedArray + +function Adapt.adapt_storage(arrayt::Type{<:SerializedArray}, a::AbstractArray) + return convert(arrayt, a) +end + +end diff --git a/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl b/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl new file mode 100644 index 0000000..d28d8e7 --- /dev/null +++ b/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl @@ -0,0 +1,25 @@ +module SerializedArraysLinearAlgebraExt + +using LinearAlgebra: LinearAlgebra, mul! +using SerializedArrays: AbstractSerializedMatrix + +function LinearAlgebra.mul!( + a_dest::AbstractMatrix, + a1::AbstractSerializedMatrix, + a2::AbstractSerializedMatrix, + α::Number, + β::Number, +) + mul!(a_dest, copy(a1), copy(a2), α, β) + return a_dest +end + +for f in [:eigen, :qr, :svd] + @eval begin + function LinearAlgebra.$f(a::AbstractSerializedMatrix; kwargs...) + return LinearAlgebra.$f(copy(a)) + end + end +end + +end diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index b888e25..b931ca9 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -1,11 +1,69 @@ module SerializedArrays +using Base.PermutedDimsArrays: genperm using ConstructionBase: constructorof -using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked -using LinearAlgebra: LinearAlgebra, mul! +using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock! using Serialization: deserialize, serialize -struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N} +abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end +const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2} +const AbstractSerializedVector{T} = AbstractSerializedArray{T,1} + +function _copyto_write!(dst, src) + writeblock!(dst, src, axes(src)...) + return dst +end +function _copyto_read!(dst, src) + readblock!(src, dst, axes(src)...) + return dst +end + +function Base.copyto!(dst::AbstractSerializedArray, src::AbstractArray) + return _copyto_write!(dst, src) +end +function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray) + return _copyto_read!(dst, src) +end +# Fix ambiguity error. +function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray) + return copyto!(dst, copy(src)) +end +# Fix ambiguity error. +function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray) + return copyto!(dst, copy(src)) +end +# Fix ambiguity error. +function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray) + return _copyto_write!(dst, src) +end +# Fix ambiguity error. +function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray) + return _copyto_read!(dst, src) +end + +function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray) + return copy(a1) == copy(a2) +end +function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray) + return a1 == copy(a2) +end +function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray) + return copy(a1) == a2 +end + +# # These cause too many ambiguity errors, try bringing them back. +# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray) +# return arrayt(a) +# end +# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray) +# return convert(arrayt, copy(a)) +# end +# # Fixes ambiguity error. +# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray) +# return convert(arrayt, copy(a)) +# end + +struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractSerializedArray{T,N} file::String axes::Axes end @@ -22,17 +80,26 @@ function SerializedArray(a::AbstractArray) return SerializedArray(tempname(), a) end +function Base.convert(arrayt::Type{<:SerializedArray}, a::AbstractArray) + return arrayt(a) +end + function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) return constructorof(arraytype(a)){elt}(undef, dims...) end +function materialize(a::SerializedArray) + return deserialize(file(a))::arraytype(a) +end function Base.copy(a::SerializedArray) - arrayt = arraytype(a) - return convert(arrayt, deserialize(file(a)))::arrayt + return materialize(a) end Base.size(a::SerializedArray) = length.(axes(a)) +to_axis(r::AbstractUnitRange) = r +to_axis(d::Integer) = Base.OneTo(d) + # # DiskArrays # @@ -64,6 +131,131 @@ function DiskArrays.create_outputarray(::Nothing, a::SerializedArray, output_siz return similar(a, output_size) end +struct PermutedSerializedArray{T,N,P<:PermutedDimsArray{T,N}} <: + AbstractSerializedArray{T,N} + permuted_parent::P +end +Base.parent(a::PermutedSerializedArray) = parent(getfield(a, :permuted_parent)) + +perm(a::PermutedSerializedArray) = perm(a.permuted_parent) +perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p + +iperm(a::PermutedSerializedArray) = iperm(a.permuted_parent) +iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip + +Base.axes(a::PermutedSerializedArray) = genperm(axes(parent(a)), perm(a)) +Base.size(a::PermutedSerializedArray) = length.(axes(a)) + +function PermutedSerializedArray(a::AbstractArray, perm) + a′ = PermutedDimsArray(a, perm) + return PermutedSerializedArray{eltype(a),ndims(a),typeof(a′)}(a′) +end + +function Base.permutedims(a::AbstractSerializedArray, perm) + return PermutedSerializedArray(a, perm) +end + +function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) + return similar(parent(a), elt, dims) +end + +function materialize(a::PermutedSerializedArray) + return PermutedDimsArray(copy(parent(a)), perm(a)) +end +function Base.copy(a::PermutedSerializedArray) + return copy(materialize(a)) +end + +haschunks(a::PermutedSerializedArray) = Unchunked() +function DiskArrays.readblock!(a::PermutedSerializedArray, aout, i::OrdinalRange...) + ip = iperm(a) + # Permute the indices + inew = genperm(i, ip) + # Permute the dest block and read from the true parent + DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, ip), inew...) + return nothing +end +function DiskArrays.writeblock!(a::PermutedSerializedArray, v, i::OrdinalRange...) + ip = iperm(a) + inew = genperm(i, ip) + # Permute the dest block and write from the true parent + DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, ip), inew...) + return nothing +end + +struct ReshapedSerializedArray{T,N,P<:AbstractArray{T},Axes} <: AbstractSerializedArray{T,N} + parent::P + axes::Axes +end +Base.parent(a::ReshapedSerializedArray) = getfield(a, :parent) +Base.axes(a::ReshapedSerializedArray) = getfield(a, :axes) + +function ReshapedSerializedArray( + a::AbstractSerializedArray, + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, +) + return ReshapedSerializedArray{eltype(a),length(ax),typeof(a),typeof(ax)}(a, ax) +end +function ReshapedSerializedArray( + a::AbstractSerializedArray, + shape::Tuple{ + Union{Integer,AbstractUnitRange{<:Integer}}, + Vararg{Union{Integer,AbstractUnitRange{<:Integer}}}, + }, +) + return ReshapedSerializedArray(a, to_axis.(shape)) +end + +Base.size(a::ReshapedSerializedArray) = length.(axes(a)) + +function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{Int}}) + return similar(parent(a), elt, dims) +end + +function materialize(a::ReshapedSerializedArray) + return reshape(materialize(parent(a)), axes(a)) +end +function Base.copy(a::ReshapedSerializedArray) + a′ = materialize(a) + return a′ isa Base.ReshapedArray ? copy(a′) : a′ +end + +# Special case for handling nested wrappers that aren't +# friendly on GPU. Consider special cases of strded arrays +# and handle with stride manipulations. +function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray}) + a′ = reshape(copy(parent(a)), axes(a)) + return a′ isa Base.ReshapedArray ? copy(a′) : a′ +end + +function Base.reshape(a::AbstractSerializedArray, dims::Tuple{Int,Vararg{Int}}) + return ReshapedSerializedArray(a, dims) +end + +DiskArrays.haschunks(a::ReshapedSerializedArray) = Unchunked() +function DiskArrays.readblock!( + a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} +) where {N} + if i == axes(a) + aout .= copy(a) + return a + end + aout .= @view copy(a)[i...] + return nothing +end +function DiskArrays.writeblock!( + a::ReshapedSerializedArray{<:Any,N}, ain, i::Vararg{AbstractUnitRange,N} +) where {N} + if i == axes(a) + serialize(file(a), ain) + return a + end + a′ = copy(a) + a′[i...] = ain + serialize(file(a), a′) + return nothing +end + # # Broadcast # @@ -86,7 +278,7 @@ function Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SerializedArrayStyle{N}) end struct BroadcastSerializedArray{T,N,BC<:Broadcasted{<:SerializedArrayStyle{N}}} <: - AbstractDiskArray{T,N} + AbstractSerializedArray{T,N} broadcasted::BC end function BroadcastSerializedArray( @@ -106,15 +298,4 @@ function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N} return BroadcastSerializedArray(flatten(broadcasted)) end -# -# LinearAlgebra -# - -function LinearAlgebra.mul!( - a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number -) - mul!(a_dest, copy(a1), copy(a2), α, β) - return a_dest -end - end diff --git a/test/Project.toml b/test/Project.toml index 3064946..d9107f5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,9 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -10,9 +12,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] +Adapt = "4" Aqua = "0.8" GPUArraysCore = "0.2" JLArrays = "0.2" +LinearAlgebra = "1.10" SafeTestsets = "0.1" SerializedArrays = "0.1" StableRNGs = "1" diff --git a/test/test_adaptext.jl b/test/test_adaptext.jl new file mode 100644 index 0000000..5f461ae --- /dev/null +++ b/test/test_adaptext.jl @@ -0,0 +1,22 @@ +using Adapt: adapt +using JLArrays: JLArray +using SerializedArrays: SerializedArray +using StableRNGs: StableRNG +using Test: @test, @testset +using TestExtras: @constinferred + +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) +arrayts = (Array, JLArray) +@testset "SerializedArraysAdaptExt (eltype=$elt, arraytype=$arrayt)" for elt in elts, + arrayt in arrayts + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = PermutedDimsArray(x, (2, 1)) + a = adapt(SerializedArray, x) + @test a isa SerializedArray{elt,2,arrayt{elt,2}} + b = adapt(SerializedArray, y) + @test b isa + PermutedDimsArray{elt,2,(2, 1),(2, 1),<:SerializedArray{elt,2,<:arrayt{elt,2}}} + @test parent(b) == a +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 5bc6690..a4269be 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,6 @@ using GPUArraysCore: @allowscalar using JLArrays: JLArray -using SerializedArrays: SerializedArray +using SerializedArrays: PermutedSerializedArray, ReshapedSerializedArray, SerializedArray using StableRNGs: StableRNG using Test: @test, @testset using TestExtras: @constinferred @@ -36,4 +36,64 @@ arrayts = (Array, JLArray) c = @constinferred(a * b) @test c == x * y @test c isa arrayt{elt,2} + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + b = similar(a) + @test b isa arrayt{elt,2} + @test size(b) == size(a) == size(x) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = permutedims(SerializedArray(x), (2, 1)) + @test a isa PermutedSerializedArray{elt,2} + @test similar(a) isa arrayt{elt,2} + @test copy(a) == permutedims(x, (2, 1)) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = reshape(SerializedArray(x), 16) + @test a isa ReshapedSerializedArray{elt,1} + @test similar(a) isa arrayt{elt,1} + @test copy(a) == reshape(x, 16) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = reshape(permutedims(SerializedArray(x), (2, 1)), 16) + @test a isa ReshapedSerializedArray{elt,1,<:PermutedSerializedArray{elt,2}} + @test similar(a) isa arrayt{elt,1} + @test copy(a) == reshape(permutedims(x, (2, 1)), 16) + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + @test a == a + @test x == a + @test a == x + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + b = SerializedArray(y) + copyto!(b, a) + @test b == a + @test b == x + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + b = SerializedArray(y) + copyto!(b, x) + @test b == a + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + copyto!(y, a) + b = SerializedArray(y) + @test b == a end diff --git a/test/test_linearalgebraext.jl b/test/test_linearalgebraext.jl new file mode 100644 index 0000000..2b5f769 --- /dev/null +++ b/test/test_linearalgebraext.jl @@ -0,0 +1,42 @@ +using JLArrays: JLArray +using LinearAlgebra: eigen, qr, svd +using SerializedArrays: SerializedArray +using StableRNGs: StableRNG +using Test: @test, @testset +using TestExtras: @constinferred + +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) +arrayts = (Array, JLArray) +@testset "SerializedArraysLinearAlgebraExt (eltype=$elt, arraytype=$arrayt)" for elt in + elts, + arrayt in arrayts + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + y = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + b = SerializedArray(y) + c = @constinferred(a * b) + @test c == x * y + @test c isa arrayt{elt,2} + + a = permutedims(SerializedArray(x), (2, 1)) + b = permutedims(SerializedArray(y), (2, 1)) + c = @constinferred(a * b) + @test c == permutedims(x, (2, 1)) * permutedims(y, (2, 1)) + @test c isa arrayt{elt,2} + + rng = StableRNG(123) + x = arrayt(randn(rng, elt, 4, 4)) + a = SerializedArray(x) + # `LinearAlgebra.eigen(::JLArray)` is broken with + # a scalar indexing issue. + if arrayt ≠ JLArray + @test eigen(a) == eigen(x) + end + Q, R = qr(a) + Qₓ, Rₓ = qr(x) + @test Q == Qₓ + @test R == Rₓ + @test svd(a) == svd(x) +end