From e97e8ea014916dfe11849a86a693aa1f2ff4216d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 15 Apr 2025 11:58:14 -0400 Subject: [PATCH] Generalize matmul, introduce disk and memory interface --- Project.toml | 2 +- .../SerializedArraysLinearAlgebraExt.jl | 32 ++++++++++-- src/SerializedArrays.jl | 51 +++++++++++-------- test/test_basics.jl | 10 +++- test/test_linearalgebraext.jl | 8 +++ 5 files changed, 77 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index a5428d3..bf08400 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SerializedArrays" uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl b/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl index d28d8e7..7eb44fd 100644 --- a/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl +++ b/ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl @@ -1,7 +1,14 @@ module SerializedArraysLinearAlgebraExt using LinearAlgebra: LinearAlgebra, mul! -using SerializedArrays: AbstractSerializedMatrix +using SerializedArrays: AbstractSerializedMatrix, memory + +function mul_serialized!( + a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number +) + mul!(a_dest, memory(a1), memory(a2), α, β) + return a_dest +end function LinearAlgebra.mul!( a_dest::AbstractMatrix, @@ -10,8 +17,27 @@ function LinearAlgebra.mul!( α::Number, β::Number, ) - mul!(a_dest, copy(a1), copy(a2), α, β) - return a_dest + return mul_serialized!(a_dest, a1, a2, α, β) +end + +function LinearAlgebra.mul!( + a_dest::AbstractMatrix, + a1::AbstractMatrix, + a2::AbstractSerializedMatrix, + α::Number, + β::Number, +) + return mul_serialized!(a_dest, a1, a2, α, β) +end + +function LinearAlgebra.mul!( + a_dest::AbstractMatrix, + a1::AbstractSerializedMatrix, + a2::AbstractMatrix, + α::Number, + β::Number, +) + return mul_serialized!(a_dest, a1, a2, α, β) end for f in [:eigen, :qr, :svd] diff --git a/src/SerializedArrays.jl b/src/SerializedArrays.jl index 6e46642..13e1808 100644 --- a/src/SerializedArrays.jl +++ b/src/SerializedArrays.jl @@ -5,6 +5,8 @@ using ConstructionBase: constructorof using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock! using Serialization: deserialize, serialize +memory(a) = a + # # AbstractSerializedArray # @@ -13,6 +15,9 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2} const AbstractSerializedVector{T} = AbstractSerializedArray{T,1} +memory(a::AbstractSerializedArray) = copy(a) +disk(a::AbstractSerializedArray) = a + function _copyto_write!(dst, src) writeblock!(dst, src, axes(src)...) return dst @@ -30,11 +35,11 @@ function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray) end # Fix ambiguity error. function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray) - return copyto!(dst, copy(src)) + return copyto!(dst, memory(src)) end # Fix ambiguity error. function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray) - return copyto!(dst, copy(src)) + return copyto!(dst, memory(src)) end # Fix ambiguity error. function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray) @@ -45,14 +50,16 @@ function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray) return _copyto_read!(dst, src) end +equals_serialized(a1, a2) = memory(a1) == memory(a2) + function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray) - return copy(a1) == copy(a2) + return equals_serialized(a1, a2) end function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray) - return a1 == copy(a2) + return equals_serialized(a1, a2) end function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray) - return copy(a1) == a2 + return equals_serialized(a1, a2) end # # These cause too many ambiguity errors, try bringing them back. @@ -60,11 +67,11 @@ end # return arrayt(a) # end # function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray) -# return convert(arrayt, copy(a)) +# return convert(arrayt, memory(a)) # end # # Fixes ambiguity error. # function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray) -# return convert(arrayt, copy(a)) +# return convert(arrayt, memory(a)) # end # @@ -79,6 +86,8 @@ file(a::SerializedArray) = getfield(a, :file) Base.axes(a::SerializedArray) = getfield(a, :axes) arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A +disk(a::AbstractArray) = SerializedArray(a) + function SerializedArray(file::String, a::AbstractArray) serialize(file, a) ax = axes(a) @@ -114,10 +123,10 @@ function DiskArrays.readblock!( a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} ) where {N} if i == axes(a) - aout .= copy(a) + aout .= memory(a) return a end - aout .= @view copy(a)[i...] + aout .= @view memory(a)[i...] return a end function DiskArrays.writeblock!( @@ -127,7 +136,7 @@ function DiskArrays.writeblock!( serialize(file(a), ain) return a end - a′ = copy(a) + a′ = memory(a) a′[i...] = ain serialize(file(a), a′) return a @@ -171,7 +180,7 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{ end function materialize(a::PermutedSerializedArray) - return PermutedDimsArray(copy(parent(a)), perm(a)) + return PermutedDimsArray(memory(parent(a)), perm(a)) end function Base.copy(a::PermutedSerializedArray) return copy(materialize(a)) @@ -241,7 +250,7 @@ end # friendly on GPU. Consider special cases of strded arrays # and handle with stride manipulations. function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray}) - a′ = reshape(copy(parent(a)), axes(a)) + a′ = reshape(memory(parent(a)), axes(a)) return a′ isa Base.ReshapedArray ? copy(a′) : a′ end @@ -254,10 +263,10 @@ function DiskArrays.readblock!( a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N} ) where {N} if i == axes(a) - aout .= copy(a) + aout .= memory(a) return a end - aout .= @view copy(a)[i...] + aout .= @view memory(a)[i...] return nothing end function DiskArrays.writeblock!( @@ -267,7 +276,7 @@ function DiskArrays.writeblock!( serialize(file(a), ain) return a end - a′ = copy(a) + a′ = memory(a) a′[i...] = ain serialize(file(a), a′) return nothing @@ -307,9 +316,9 @@ end DiskArrays.haschunks(a::SubSerializedArray) = Unchunked() function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...) if i == axes(a) - aout .= copy(a) + aout .= memory(a) end - aout[i...] = copy(view(a, i...)) + aout[i...] = memory(view(a, i...)) return nothing end function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...) @@ -317,7 +326,7 @@ function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...) serialize(file(a), ain) return a end - a_parent = copy(parent(a)) + a_parent = memory(parent(a)) pinds = parentindices(view(a.sub_parent, i...)) a_parent[pinds...] = ain serialize(file(a), a_parent) @@ -349,7 +358,7 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg end function materialize(a::TransposeSerializedArray) - return transpose(copy(parent(a))) + return transpose(memory(parent(a))) end function Base.copy(a::TransposeSerializedArray) return copy(materialize(a)) @@ -392,7 +401,7 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I end function materialize(a::AdjointSerializedArray) - return adjoint(copy(parent(a))) + return adjoint(memory(parent(a))) end function Base.copy(a::AdjointSerializedArray) return copy(materialize(a)) @@ -445,7 +454,7 @@ Base.size(a::BroadcastSerializedArray) = size(a.broadcasted) Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted function Base.copy(a::BroadcastSerializedArray) # Broadcast over the materialized arrays. - return copy(Base.Broadcast.broadcasted(a.broadcasted.f, copy.(a.broadcasted.args)...)) + return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...)) end function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N} diff --git a/test/test_basics.jl b/test/test_basics.jl index ca6087a..b686735 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -6,7 +6,9 @@ using SerializedArrays: ReshapedSerializedArray, SerializedArray, SubSerializedArray, - TransposeSerializedArray + TransposeSerializedArray, + disk, + memory using StableRNGs: StableRNG using Test: @test, @testset using TestExtras: @constinferred @@ -21,6 +23,12 @@ arrayts = (Array, JLArray) a = SerializedArray(x) @test @constinferred(copy(a)) == x @test typeof(copy(a)) == typeof(x) + @test memory(a) == x + @test memory(a) isa arrayt{elt,2} + @test memory(x) === x + @test disk(a) === a + @test disk(x) == a + @test disk(x) isa SerializedArray{elt,2,<:arrayt{elt,2}} x = arrayt(zeros(elt, 4, 4)) a = SerializedArray(x) diff --git a/test/test_linearalgebraext.jl b/test/test_linearalgebraext.jl index 2b5f769..a703a29 100644 --- a/test/test_linearalgebraext.jl +++ b/test/test_linearalgebraext.jl @@ -20,6 +20,14 @@ arrayts = (Array, JLArray) @test c == x * y @test c isa arrayt{elt,2} + c = @constinferred(x * b) + @test c == x * y + @test c isa arrayt{elt,2} + + c = @constinferred(a * y) + @test c == x * y + @test c isa arrayt{elt,2} + a = permutedims(SerializedArray(x), (2, 1)) b = permutedims(SerializedArray(y), (2, 1)) c = @constinferred(a * b)