From dc49ebb87d72a4fb97db4394373f51d951f751d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 5 Jan 2026 14:51:28 -0500 Subject: [PATCH 1/7] [WIP] Switch to FunctionImplementations --- Project.toml | 8 +- src/abstractsparsearray.jl | 64 ++++++------- src/abstractsparsearrayinterface.jl | 105 ++++---------------- src/indexing.jl | 113 ++++++++++++---------- src/map.jl | 31 +++--- src/oneelementarray.jl | 6 +- src/sparsearraydok.jl | 20 ++-- src/sparsearrayinterface.jl | 142 ++++++++++++++++++++++------ src/sparsearrays.jl | 2 +- src/wrappers.jl | 28 +++--- test/Project.toml | 2 +- 11 files changed, 277 insertions(+), 244 deletions(-) diff --git a/Project.toml b/Project.toml index 5180d3a..def7fdd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,15 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.7.11" +version = "0.7.12" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" @@ -29,13 +29,13 @@ SparseArraysBaseTensorAlgebraExt = ["TensorAlgebra", "SparseArrays"] Accessors = "0.1.41" Adapt = "4.3" ArrayLayouts = "1.11" -DerivableInterfaces = "0.5" Dictionaries = "0.4.3" FillArrays = "1.13" +FunctionImplementations = "0.3.0" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.5" -NamedDimsArrays = "0.11" +NamedDimsArrays = "0.12" Random = "1.10" SparseArrays = "1.10" TensorAlgebra = "0.6.2" diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index 3f13ccf..b5c1115 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -4,12 +4,12 @@ abstract type AbstractSparseArray{T, N} <: AbstractArray{T, N} end Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a) -using DerivableInterfaces: @array_aliases -# Define AbstractSparseVector, AnyAbstractSparseArray, etc. -@array_aliases AbstractSparseArray +## using DerivableInterfaces: @array_aliases +## # Define AbstractSparseVector, AnyAbstractSparseArray, etc. +## @array_aliases AbstractSparseArray -using DerivableInterfaces: DerivableInterfaces -function DerivableInterfaces.interface(::Type{<:AbstractSparseArray}) +using FunctionImplementations: FunctionImplementations +function FunctionImplementations.Style(::Type{<:AbstractSparseArray}) return SparseArrayInterface() end @@ -49,7 +49,7 @@ function Base.similar( return similar_sparsearray(a, T, ax) end -using DerivableInterfaces: @derive +## using DerivableInterfaces: @derive # TODO: These need to be loaded since `AbstractArrayOps` # includes overloads of functions from these modules. @@ -58,32 +58,32 @@ using DerivableInterfaces: @derive using ArrayLayouts: ArrayLayouts using LinearAlgebra: LinearAlgebra -@derive (T = AnyAbstractSparseArray,) begin - Base.getindex(::T, ::Any...) - Base.getindex(::T, ::Int...) - Base.setindex!(::T, ::Any, ::Any...) - Base.setindex!(::T, ::Any, ::Int...) - Base.copy!(::AbstractArray, ::T) - Base.copyto!(::AbstractArray, ::T) - Base.map(::Any, ::T...) - Base.map!(::Any, ::AbstractArray, ::T...) - Base.mapreduce(::Any, ::Any, ::T...; kwargs...) - Base.reduce(::Any, ::T...; kwargs...) - Base.all(::Function, ::T) - Base.all(::T) - Base.iszero(::T) - Base.real(::T) - Base.fill!(::T, ::Any) - DerivableInterfaces.zero!(::T) - Base.zero(::T) - Base.permutedims!(::Any, ::T, ::Any) - Broadcast.BroadcastStyle(::Type{<:T}) - Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) - ArrayLayouts.MemoryLayout(::Type{<:T}) - LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) -end - -using DerivableInterfaces.Concatenate: concatenate +## @derive (T = AnyAbstractSparseArray,) begin +## Base.getindex(::T, ::Any...) +## Base.getindex(::T, ::Int...) +## Base.setindex!(::T, ::Any, ::Any...) +## Base.setindex!(::T, ::Any, ::Int...) +## Base.copy!(::AbstractArray, ::T) +## Base.copyto!(::AbstractArray, ::T) +## Base.map(::Any, ::T...) +## Base.map!(::Any, ::AbstractArray, ::T...) +## Base.mapreduce(::Any, ::Any, ::T...; kwargs...) +## Base.reduce(::Any, ::T...; kwargs...) +## Base.all(::Function, ::T) +## Base.all(::T) +## Base.iszero(::T) +## Base.real(::T) +## Base.fill!(::T, ::Any) +## DerivableInterfaces.zero!(::T) +## Base.zero(::T) +## Base.permutedims!(::Any, ::T, ::Any) +## Broadcast.BroadcastStyle(::Type{<:T}) +## Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) +## ArrayLayouts.MemoryLayout(::Type{<:T}) +## LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) +## end + +using FunctionImplementations.Concatenate: concatenate # We overload `Base._cat` instead of `Base.cat` since it # is friendlier for invalidations/compile times, see # https://github.com/ITensor/SparseArraysBase.jl/issues/25. diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index 32a20d4..36e7f7f 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -1,7 +1,6 @@ using Base: @_propagate_inbounds_meta -using DerivableInterfaces: - DerivableInterfaces, @derive, @interface, AbstractArrayInterface, zero! using FillArrays: Zeros +using FunctionImplementations: FunctionImplementations function unstored end function eachstoredindex end @@ -52,35 +51,32 @@ function dense(a::AbstractArray) return @allowscalar convert(densetype(a), a) end -# Minimal interface for `SparseArrayInterface`. +# Minimal interface for `SparseArrayStyle`. # Fallbacks for dense/non-sparse arrays. -# TODO: Add `ndims` type parameter, like `Base.Broadcast.AbstractArrayStyle`. -# TODO: This isn't used to define interface functions right now. -# Currently, `@interface` expects an instance, probably it should take a -# type instead so fallback functions can use abstract types. -abstract type AbstractSparseArrayInterface{N} <: AbstractArrayInterface{N} end +using FunctionImplementations: AbstractArrayStyle +abstract type AbstractSparseArrayStyle <: AbstractArrayStyle end -function DerivableInterfaces.combine_interface_rule( - interface1::AbstractSparseArrayInterface, interface2::AbstractSparseArrayInterface +function FunctionImplementations.Style( + style1::AbstractSparseArrayStyle, style2::AbstractSparseArrayStyle ) return error("Rule not defined.") end -function DerivableInterfaces.combine_interface_rule( - interface1::Interface, interface2::Interface - ) where {Interface <: AbstractSparseArrayInterface} - return interface1 -end -function DerivableInterfaces.combine_interface_rule( - interface1::AbstractSparseArrayInterface, interface2::AbstractArrayInterface - ) - return interface1 -end -function DerivableInterfaces.combine_interface_rule( - interface1::AbstractArrayInterface, interface2::AbstractSparseArrayInterface +## function FunctionImplementations.Style( +## style1::Style, style2::Style +## ) where {Style <: AbstractSparseArrayStyle} +## return style1 +## end +function FunctionImplementations.Style( + style1::AbstractSparseArrayStyle, style2::AbstractArrayStyle ) - return interface2 + return style1 end +## function FunctionImplementations.Style( +## style1::AbstractArrayStyle, style2::AbstractSparseArrayStyle +## ) +## return style2 +## end to_vec(x) = vec(collect(x)) to_vec(x::AbstractArray) = vec(x) @@ -106,63 +102,6 @@ Base.size(a::StoredValues) = size(a.storedindices) return setindex!(a.array, value, a.storedindices[I]) end -using DerivableInterfaces: DerivableInterfaces, zero! - -# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` -# and is useful for sparse array logic, since it can be used to empty -# the sparse array storage. -# We use a single function definition to minimize method ambiguities. -@interface interface::AbstractSparseArrayInterface function DerivableInterfaces.zero!( - a::AbstractArray - ) - # More generally, this codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! - @inbounds for I in eachstoredindex(a) - a[I] = f(a[I]) - end - return a -end - -# `f::typeof(norm)`, `op::typeof(max)` used by `norm`. -function reduce_init(f, op, as...) - # TODO: Generalize this. - @assert isone(length(as)) - a = only(as) - ## TODO: Make this more efficient for block sparse - ## arrays, in that case it allocates a block. Maybe - ## it can use `FillArrays.Zeros`. - return f(getunstoredindex(a, first(eachindex(a)))) -end - -@interface ::AbstractSparseArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; init = reduce_init(f, op, as...), kwargs... - ) - # TODO: Generalize this. - @assert isone(length(as)) - a = only(as) - output = mapreduce(f, op, storedvalues(a); init, kwargs...) - ## TODO: Bring this check back, or make the function more general. - ## f_notstored = apply_notstored(f, a) - ## @assert isequal(op(output, eltype(output)(f_notstored)), output) - return output -end - -abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end - -@derive (T = AbstractSparseArrayStyle,) begin - Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple) - Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:T}) -end - -struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end - -SparseArrayStyle{M}(::Val{N}) where {M, N} = SparseArrayStyle{N}() - -@interface ::AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type) - return SparseArrayStyle{ndims(type)}() -end - using ArrayLayouts: ArrayLayouts, MatMulMatAdd abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end @@ -230,9 +169,3 @@ function ArrayLayouts.materialize!( sparse_mul!(m.C, m.A, m.B, m.α, m.β) return m.C end - -struct SparseLayout <: AbstractSparseLayout end - -@interface ::AbstractSparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - return SparseLayout() -end diff --git a/src/indexing.jl b/src/indexing.jl index 85fdead..03edac2 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -1,4 +1,5 @@ using Base: @_propagate_inbounds_meta +using FunctionImplementations: style # Indexing interface # ------------------ @@ -12,9 +13,9 @@ Obtain `getindex(A, I...)` with the guarantee that there is a stored entry at th Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline getstoredindex(A::AbstractArray, I...) = @interface interface(A) getstoredindex( - A, I... -) +@inline function getstoredindex(A::AbstractArray, I...) + return style(A)(getstoredindex)(A, I...) +end """ getunstoredindex(A::AbstractArray, I...) -> eltype(A) @@ -26,9 +27,9 @@ instantiated object. Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline getunstoredindex(A::AbstractArray, I...) = @interface interface(A) getunstoredindex( - A, I... -) +@inline function getunstoredindex(A::AbstractArray, I...) + return style(A)(getunstoredindex)(A, I...) +end """ isstored(A::AbstractArray, I...) -> Bool @@ -39,7 +40,9 @@ sparse array types might overload this function when appropriate. Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline isstored(A::AbstractArray, I...) = @interface interface(A) isstored(A, I...) +@inline function isstored(A::AbstractArray, I...) + return style(A)(isstored)(A, I...) +end """ setstoredindex!(A::AbstractArray, v, I...) -> A @@ -48,9 +51,9 @@ Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. """ -@inline setstoredindex!(A::AbstractArray, v, I...) = @interface interface(A) setstoredindex!( - A, v, I... -) +@inline function setstoredindex!(A::AbstractArray, v, I...) + return style(A)(setstoredindex!)(A, v, I...) +end """ setunstoredindex!(A::AbstractArray, v, I...) -> A @@ -59,9 +62,9 @@ Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle( Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. """ -@inline setunstoredindex!(A::AbstractArray, v, I...) = @interface interface(A) setunstoredindex!( - A, v, I... -) +@inline function setunstoredindex!(A::AbstractArray, v, I...) + return style(A)(setunstoredindex!)(A, v, I...) +end # Indices interface # ----------------- @@ -109,13 +112,13 @@ to be the same as [`eachstoredindex`](@ref). """ function storedvalues end -@derive (T = AbstractArray,) begin - SparseArraysBase.eachstoredindex(::T...) - SparseArraysBase.eachstoredindex(::IndexStyle, ::T...) - SparseArraysBase.storedlength(::T) - SparseArraysBase.storedpairs(::T) - SparseArraysBase.storedvalues(::T) -end +## @derive (T = AbstractArray,) begin +## SparseArraysBase.eachstoredindex(::T...) +## SparseArraysBase.eachstoredindex(::IndexStyle, ::T...) +## SparseArraysBase.storedlength(::T) +## SparseArraysBase.storedpairs(::T) +## SparseArraysBase.storedvalues(::T) +## end # canonical indexing # ------------------ @@ -127,7 +130,7 @@ for f in (:isstored, :getunstoredindex, :getstoredindex) _f = Symbol(:_, f) error_if_canonical = Symbol(:error_if_canonical_, f) @eval begin - @interface ::AbstractArrayInterface function $f(A::AbstractArray, I...) + function $f(A::AbstractArray, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -172,7 +175,7 @@ for f! in (:setunstoredindex!, :setstoredindex!) _f! = Symbol(:_, f!) error_if_canonical = Symbol(:error_if_canonical_, f!) @eval begin - @interface ::AbstractArrayInterface function $f!(A::AbstractArray, v, I...) + function $f!(A::AbstractArray, v, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -215,65 +218,67 @@ end # AbstractArrayInterface fallback definitions # ------------------------------------------- -@interface ::AbstractArrayInterface function isstored(A::AbstractArray, i::Int, I::Int...) +function isstored(A::AbstractArray, i::Int, I::Int...) @inline @boundscheck checkbounds(A, i, I...) return true end -@interface ::AbstractArrayInterface function getunstoredindex(A::AbstractArray, I::Int...) +function getunstoredindex(A::AbstractArray, I::Int...) @inline @boundscheck checkbounds(A, I...) return zero(eltype(A)) end -@interface ::AbstractArrayInterface function getstoredindex(A::AbstractArray, I::Int...) +function getstoredindex(A::AbstractArray, I::Int...) @inline return getindex(A, I...) end -@interface ::AbstractArrayInterface function setstoredindex!(A::AbstractArray, v, I::Int...) +function setstoredindex!(A::AbstractArray, v, I::Int...) @inline return setindex!(A, v, I...) end -@interface ::AbstractArrayInterface setunstoredindex!(A::AbstractArray, v, I::Int...) = error( +setunstoredindex!(A::AbstractArray, v, I::Int...) = error( "setunstoredindex! for $(typeof(A)) is not supported" ) -@interface ::AbstractArrayInterface eachstoredindex(A::AbstractArray, B::AbstractArray...) = eachstoredindex( +eachstoredindex(A::AbstractArray, B::AbstractArray...) = eachstoredindex( IndexStyle(A, B...), A, B... ) -@interface ::AbstractArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray, B::AbstractArray...) = eachindex( +eachstoredindex(style::IndexStyle, A::AbstractArray, B::AbstractArray...) = eachindex( style, A, B... ) -@interface ::AbstractArrayInterface storedvalues(A::AbstractArray) = values(A) -@interface ::AbstractArrayInterface storedpairs(A::AbstractArray) = pairs(A) -@interface ::AbstractArrayInterface storedlength(A::AbstractArray) = length(storedvalues(A)) +storedvalues(A::AbstractArray) = values(A) +storedpairs(A::AbstractArray) = pairs(A) +storedlength(A::AbstractArray) = length(storedvalues(A)) # SparseArrayInterface implementations # ------------------------------------ # canonical errors are moved to `isstored`, `getstoredindex` and `getunstoredindex` # so no errors at this level by defining both IndexLinear and IndexCartesian -@interface ::AbstractSparseArrayInterface function Base.getindex( +const getindex_sparse = sparse_style(getindex) +function getindex_sparse( A::AbstractArray{<:Any, N}, I::Vararg{Int, N} ) where {N} @_propagate_inbounds_meta @boundscheck checkbounds(A, I...) # generally isstored requires bounds checking return @inbounds isstored(A, I...) ? getstoredindex(A, I...) : getunstoredindex(A, I...) end -@interface ::AbstractSparseArrayInterface function Base.getindex(A::AbstractArray, I::Int) +function getindex_sparse(A::AbstractArray, I::Int) @_propagate_inbounds_meta @boundscheck checkbounds(A, I) return @inbounds isstored(A, I) ? getstoredindex(A, I) : getunstoredindex(A, I) end # disambiguate vectors -@interface ::AbstractSparseArrayInterface function Base.getindex(A::AbstractVector, I::Int) +function getindex_sparse(A::AbstractVector, I::Int) @_propagate_inbounds_meta @boundscheck checkbounds(A, I) return @inbounds isstored(A, I) ? getstoredindex(A, I) : getunstoredindex(A, I) end -@interface ::AbstractSparseArrayInterface function Base.setindex!( +const setindex!_sparse = sparse_style(setindex!) +function setindex!_sparse( A::AbstractArray{<:Any, N}, v, I::Vararg{Int, N} ) where {N} @_propagate_inbounds_meta @@ -284,7 +289,7 @@ end setunstoredindex!(A, v, I...) end end -@interface ::AbstractSparseArrayInterface function Base.setindex!( +function setindex!_sparse( A::AbstractArray, v, I::Int ) @_propagate_inbounds_meta @@ -296,7 +301,7 @@ end end end # disambiguate vectors -@interface ::AbstractSparseArrayInterface function Base.setindex!( +function setindex!_sparse( A::AbstractVector, v, I::Int ) @_propagate_inbounds_meta @@ -314,7 +319,8 @@ end end # required: one implementation for canonical index style -@interface ::AbstractSparseArrayInterface function eachstoredindex( +const eachstoredindex_sparse = sparse_style(eachstoredindex) +function eachstoredindex_sparse( style::IndexStyle, A::AbstractArray ) error_if_canonical_eachstoredindex(style, A) @@ -331,18 +337,20 @@ end end # derived but may be specialized: -@interface ::AbstractSparseArrayInterface function eachstoredindex( +function eachstoredindex_sparse( style::IndexStyle, A::AbstractArray, B::AbstractArray... ) return union(map(Base.Fix1(eachstoredindex, style), (A, B...))...) end -@interface ::AbstractSparseArrayInterface storedvalues(A::AbstractArray) = StoredValues(A) +const storedvalues_sparse = sparse_style(storedvalues) +storedvalues_sparse(A::AbstractArray) = StoredValues(A) # default implementation is a bit tricky here: we don't know if this is the "canonical" # implementation, so we check this and otherwise map back to `_isstored` to canonicalize the # indices -@interface ::AbstractSparseArrayInterface function isstored(A::AbstractArray, I::Int...) +const isstored_sparse = sparse_style(isstored) +function isstored_sparse(A::AbstractArray, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) # canonical linear indexing @@ -361,7 +369,8 @@ end return _isstored(style, A, Base.to_indices(A, I)...) end -@interface ::AbstractSparseArrayInterface function getunstoredindex( +const getunstoredindex_sparse = sparse_style(getunstoredindex) +function getunstoredindex_sparse( A::AbstractArray, I::Int... ) @_propagate_inbounds_meta @@ -383,8 +392,8 @@ end return _getunstoredindex(style, A, Base.to_indices(A, I)...) end -# make sure we don't call AbstractArrayInterface defaults -@interface ::AbstractSparseArrayInterface function getstoredindex( +const getstoredindex_sparse = sparse_style(getstoredindex) +function getstoredindex_sparse( A::AbstractArray, I::Int... ) @_propagate_inbounds_meta @@ -393,11 +402,13 @@ end return _getstoredindex(style, A, Base.to_indices(A, I)...) end -for f! in (:setstoredindex!, :setunstoredindex!) +const setstoredindex!_sparse = sparse_style(setstoredindex!) +const setunstoredindex!_sparse = sparse_style(setunstoredindex!) +for f! in (:setstoredindex!_sparse, :setunstoredindex!_sparse) _f! = Symbol(:_, f!) error_if_canonical_setstoredindex = Symbol(:error_if_canonical_, f!) @eval begin - @interface ::AbstractSparseArrayInterface function $f!(A::AbstractArray, v, I::Int...) + function $f!(A::AbstractArray, v, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical_setstoredindex(style, A, I...) @@ -406,10 +417,12 @@ for f! in (:setstoredindex!, :setunstoredindex!) end end -@interface ::AbstractSparseArrayInterface storedlength(A::AbstractArray) = length( +const storedlength_sparse = sparse_style(storedlength) +storedlength_sparse(A::AbstractArray) = length( storedvalues(A) ) -@interface ::AbstractSparseArrayInterface function storedpairs(A::AbstractArray) +const storedpairs_sparse = sparse_style(storedpairs) +function storedpairs_sparse(A::AbstractArray) return Iterators.map(I -> (I => A[I]), eachstoredindex(A)) end @@ -426,7 +439,7 @@ for (Tr, Tc) in Iterators.product( ) Tr === Tc === :Integer && continue @eval begin - @interface ::AbstractSparseArrayInterface function Base.getindex( + function getindex_sparse( A::AbstractMatrix, kr::$Tr, jr::$Tc ) Base.@inline # needed to make boundschecks work diff --git a/src/map.jl b/src/map.jl index 83efe6b..1ab5a52 100644 --- a/src/map.jl +++ b/src/map.jl @@ -67,15 +67,16 @@ end # map(!) # ------ -@interface I::AbstractSparseArrayInterface function Base.map( +const map_sparse = sparse_style(map) +function map_sparse( f, A::AbstractArray, Bs::AbstractArray... ) f_pres = ZeroPreserving(f, A, Bs...) - return map_sparsearray(f_pres, A, Bs...) + return map_sparse(f_pres, A, Bs...) end # This isn't an overload of `Base.map` since that leads to ambiguity errors. -function map_sparsearray(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...) +function map_sparse(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...) T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...)) C = similar(A, T) # TODO: Instead use: @@ -84,18 +85,19 @@ function map_sparsearray(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray. # C = similar(A, Unstored(U)) # ``` # though right now `map` doesn't preserve `Zeros` or `BlockZeros`. - return map_sparsearray!(f, C, A, Bs...) + return map!_sparse(f, C, A, Bs...) end -@interface I::AbstractSparseArrayInterface function Base.map!( +const map!_sparse = sparse_style(map!) +function map!_sparse( f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray... ) f_pres = ZeroPreserving(f, A, Bs...) - return map_sparsearray!(f_pres, C, A, Bs...) + return map!_sparse(f_pres, C, A, Bs...) end # This isn't an overload of `Base.map!` since that leads to ambiguity errors. -function map_sparsearray!( +function map!_sparse( f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray... ) checkshape(C, A, Bs...) @@ -121,30 +123,33 @@ end # Derived functions # ----------------- -@interface I::AbstractSparseArrayInterface function Base.copyto!( +const copyto!_sparse = sparse_style(copyto!) +function copyto!_sparse( dest::AbstractArray, src::AbstractArray ) - @interface I map!(identity, dest, src) + map!_sparse(identity, dest, src) return dest end # Only map the stored values of the inputs. function map_stored! end -@interface interface::AbstractArrayInterface function map_stored!( +const map_stored!_sparse = sparse_style(map_stored!) +function map_stored!_sparse( f, a_dest::AbstractArray, as::AbstractArray... ) - @interface interface map!(WeakPreserving(f), a_dest, as...) + map!_sparse(WeakPreserving(f), a_dest, as...) return a_dest end # Only map all values, not just the stored ones. function map_all! end -@interface interface::AbstractArrayInterface function map_all!( +const map_all!_sparse = sparse_style(map_all!) +function map_all!_sparse( f, a_dest::AbstractArray, as::AbstractArray... ) - @interface interface map!(NonPreserving(f), a_dest, as...) + map!_sparse(NonPreserving(f), a_dest, as...) return a_dest end diff --git a/src/oneelementarray.jl b/src/oneelementarray.jl index e276169..ea21b63 100644 --- a/src/oneelementarray.jl +++ b/src/oneelementarray.jl @@ -16,9 +16,9 @@ struct OneElementArray{T, N, I, Unstored <: AbstractArray{T, N}} <: AbstractSpar end end -using DerivableInterfaces: @array_aliases -# Define `OneElementMatrix`, `AnyOneElementArray`, etc. -@array_aliases OneElementArray +## using DerivableInterfaces: @array_aliases +## # Define `OneElementMatrix`, `AnyOneElementArray`, etc. +## @array_aliases OneElementArray function OneElementArray{T, N}( value, index::NTuple{N, Int}, axes::NTuple{N, AbstractUnitRange} diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 6441d34..6d61ec3 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -1,5 +1,5 @@ using Accessors: @set -using DerivableInterfaces: DerivableInterfaces, @interface, interface, zero! +using FunctionImplementations: FunctionImplementations, zero! using Dictionaries: Dictionary, IndexError, set! const DOKStorage{T, N} = Dictionary{CartesianIndex{N}, T} @@ -65,15 +65,15 @@ function SparseArrayDOK{T}(::UndefInitializer, ax::Vararg{Any, N}) where {T, N} return SparseArrayDOK{T, N}(undef, ax) end -using DerivableInterfaces: DerivableInterfaces -# This defines the destination type of various operations in DerivableInterfaces.jl. -function Base.similar(::AbstractSparseArrayInterface, T::Type, ax::Tuple) - return similar(SparseArrayDOK{T}, ax) -end +## using DerivableInterfaces: DerivableInterfaces +## # This defines the destination type of various operations in DerivableInterfaces.jl. +## function Base.similar(::AbstractSparseArrayInterface, T::Type, ax::Tuple) +## return similar(SparseArrayDOK{T}, ax) +## end -using DerivableInterfaces: @array_aliases -# Define `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. -@array_aliases SparseArrayDOK +## using DerivableInterfaces: @array_aliases +## # Define `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. +## @array_aliases SparseArrayDOK storage(a::SparseArrayDOK) = a.storage @@ -115,7 +115,7 @@ end storedpairs(a::SparseArrayDOK) = pairs(storage(a)) # TODO: Also handle wrappers. -function DerivableInterfaces.zero!(a::SparseArrayDOK) +function FunctionImplementations.zero!(a::SparseArrayDOK) empty!(storage(a)) return a end diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index a0307fd..544d697 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -1,36 +1,118 @@ -using DerivableInterfaces: DerivableInterfaces - -struct SparseArrayInterface{N} <: AbstractSparseArrayInterface{N} end -SparseArrayInterface() = SparseArrayInterface{Any}() -SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}() -SparseArrayInterface{M}(::Val{N}) where {M, N} = SparseArrayInterface{N}() - -# Fix ambiguity error. -function DerivableInterfaces.combine_interface_rule( - ::SparseArrayInterface{N}, ::SparseArrayInterface{N} - ) where {N} - return SparseArrayInterface{N}() -end -function DerivableInterfaces.combine_interface_rule( - ::SparseArrayInterface, ::SparseArrayInterface - ) - return SparseArrayInterface() +using FunctionImplementations: FunctionImplementations + +struct SparseArrayStyle <: AbstractSparseArrayStyle end + +# Convenient shorthand to refer to the sparse interface. +# Can turn a function into a sparse function with the syntax `sparse_style(f)`, +# i.e. `sparse_style(map)(x -> 2x, randn(2, 2))` while use the sparse +# version of `map`. +const sparse_style = SparseArrayStyle() + +## # Fix ambiguity error. +## function DerivableInterfaces.combine_interface_rule( +## ::SparseArrayInterface{N}, ::SparseArrayInterface{N} +## ) where {N} +## return SparseArrayInterface{N}() +## end +## function DerivableInterfaces.combine_interface_rule( +## ::SparseArrayInterface, ::SparseArrayInterface +## ) +## return SparseArrayInterface() +## end +## function DerivableInterfaces.combine_interface_rule( +## interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface +## ) +## return interface1 +## end +## function DerivableInterfaces.combine_interface_rule( +## interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface +## ) +## return interface2 +## end + +## FunctionImplementations.Style(::Type{<:AbstractSparseArrayStyle}) = SparseArrayStyle() + +## using FunctionImplementations: zero! +## +## # `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` +## # and is useful for sparse array logic, since it can be used to empty +## # the sparse array storage. +## # We use a single function definition to minimize method ambiguities. +## const zero!_sparse = sparse_style(zero!) +## function zero!_sparse(a::AbstractArray) +## # More generally, this codepath could be taking if `zero(eltype(a))` +## # is defined and the elements are immutable. +## f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! +## @inbounds for I in eachstoredindex(a) +## a[I] = f(a[I]) +## end +## return a +## end + +using FunctionImplementations: FunctionImplementations, zero! + +# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` +# and is useful for sparse array logic, since it can be used to empty +# the sparse array storage. +# We use a single function definition to minimize method ambiguities. +const zero!_sparse = sparse_style(zero!) +function zero!_sparse(a::AbstractArray) + # More generally, this codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! + @inbounds for I in eachstoredindex(a) + a[I] = f(a[I]) + end + return a end -function DerivableInterfaces.combine_interface_rule( - interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface - ) - return interface1 + +# `f::typeof(norm)`, `op::typeof(max)` used by `norm`. +function reduce_init(f, op, as...) + # TODO: Generalize this. + @assert isone(length(as)) + a = only(as) + ## TODO: Make this more efficient for block sparse + ## arrays, in that case it allocates a block. Maybe + ## it can use `FillArrays.Zeros`. + return f(getunstoredindex(a, first(eachindex(a)))) end -function DerivableInterfaces.combine_interface_rule( - interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface + +const mapreduce_sparse = sparse_style(mapreduce) +function mapreduce_sparse( + f, op, as::AbstractArray...; init = reduce_init(f, op, as...), kwargs... ) - return interface2 + # TODO: Generalize this. + @assert isone(length(as)) + a = only(as) + output = mapreduce(f, op, storedvalues(a); init, kwargs...) + ## TODO: Bring this check back, or make the function more general. + ## f_notstored = apply_notstored(f, a) + ## @assert isequal(op(output, eltype(output)(f_notstored)), output) + return output end -# Convenient shorthand to refer to the sparse interface. -# Can turn a function into a sparse function with the syntax `sparse(f)`, -# i.e. `sparse(map)(x -> 2x, randn(2, 2))` while use the sparse -# version of `map`. -# const sparse = SparseArrayInterface() +module Broadcast + + abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end + + ## @derive (T = AbstractSparseArrayStyle,) begin + ## Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple) + ## Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:T}) + ## end + + struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end + + SparseArrayStyle{M}(::Val{N}) where {M, N} = SparseArrayStyle{N}() + + # TODO: Don't make this a `sparse_style` function. + BroadcastStyle_sparse = sparse_style(Broadcast.BroadcastStyle) + function BroadcastStyle_sparse(type::Type) + return SparseArrayStyle{ndims(type)}() + end + +end # module Broadcast -DerivableInterfaces.interface(::Type{<:AbstractSparseArrayStyle}) = SparseArrayInterface() +# TODO: Don't make this a `sparse_style` function. +struct SparseLayout <: AbstractSparseLayout end +const MemoryLayout_sparse = sparse_style(ArrayLayouts.MemoryLayout) +MemoryLayout_sparse(type::Type) = SparseLayout() diff --git a/src/sparsearrays.jl b/src/sparsearrays.jl index ec03616..e8f7b7c 100644 --- a/src/sparsearrays.jl +++ b/src/sparsearrays.jl @@ -6,7 +6,7 @@ function eachstoredindex(m::AbstractSparseMatrixCSC) return Iterators.map(CartesianIndex, zip(I, J)) end function eachstoredindex(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC}) - return @interface SparseArrayInterface() eachstoredindex(a) + return eachstoredindex_sparse(a) end function SparseArrays.SparseMatrixCSC{Tv, Ti}(m::AnyAbstractSparseMatrix) where {Tv, Ti} diff --git a/src/wrappers.jl b/src/wrappers.jl index 3851114..47df938 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -148,30 +148,30 @@ function isstored(a::Transpose, I::Vararg{Int, 2}) return isstored_wrapped(a, I...) end -# TODO: Turn these into `AbstractWrappedSparseArrayInterface` functions? +# TODO: Turn these into `AbstractWrappedSparseArrayStyle` functions? for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose) @eval begin - @interface ::AbstractSparseArrayInterface storedvalues(a::$type) = storedparentvalues(a) - @interface ::AbstractSparseArrayInterface function eachstoredindex(a::$type) + storedvalues_sparse(a::$type) = storedparentvalues(a) + function eachstoredindex_sparse(a::$type) return map(Base.Fix1(parentindex_to_index, a), eachstoredparentindex(a)) end - @interface ::AbstractSparseArrayInterface function eachstoredindex( + function eachstoredindex_sparse( style::IndexStyle, a::$type ) # TODO: Make lazy with `Iterators.map`. return map(Base.Fix1(parentindex_to_index, a), eachstoredparentindex(style, a)) end - @interface ::AbstractSparseArrayInterface function getstoredindex(a::$type, I::Int...) + function getstoredindex_sparse(a::$type, I::Int...) return parentvalue_to_value( a, getstoredindex(parent(a), index_to_parentindex(a, I...)...) ) end - @interface ::AbstractSparseArrayInterface function getunstoredindex(a::$type, I::Int...) + function getunstoredindex_sparse(a::$type, I::Int...) return parentvalue_to_value( a, getunstoredindex(parent(a), index_to_parentindex(a, I...)...) ) end - @interface ::AbstractSparseArrayInterface function setstoredindex!( + function setstoredindex!_sparse( a::$type, value, I::Int... ) setstoredindex!( @@ -179,7 +179,7 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose ) return a end - @interface ::AbstractSparseArrayInterface function setunstoredindex!( + function setunstoredindex!_sparse( a::$type, value, I::Int... ) setunstoredindex!( @@ -191,7 +191,7 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose end using LinearAlgebra: LinearAlgebra, Diagonal -@interface ::AbstractArrayInterface storedvalues(D::Diagonal) = LinearAlgebra.diag(D) +storedvalues_sparse(D::Diagonal) = LinearAlgebra.diag(D) # compat with LTS: @static if VERSION ≥ v"1.11" @@ -201,20 +201,20 @@ else return view(CartesianIndices(x), LinearAlgebra.diagind(x)) end end -@interface ::AbstractArrayInterface eachstoredindex(D::Diagonal) = _diagind( +eachstoredindex_sparse(D::Diagonal) = _diagind( D, IndexCartesian() ) -@interface ::AbstractArrayInterface function isstored(D::Diagonal, i::Int, j::Int) +function isstored_sparse(D::Diagonal, i::Int, j::Int) return i == j && checkbounds(Bool, D, i, j) end -@interface ::AbstractArrayInterface function getstoredindex(D::Diagonal, i::Int, j::Int) +function getstoredindex_sparse(D::Diagonal, i::Int, j::Int) return D.diag[i] end -@interface ::AbstractArrayInterface function getunstoredindex(D::Diagonal, i::Int, j::Int) +function getunstoredindex_sparse(D::Diagonal, i::Int, j::Int) return zero(eltype(D)) end -@interface ::AbstractArrayInterface function setstoredindex!(D::Diagonal, v, i::Int, j::Int) +function setstoredindex!_sparse(D::Diagonal, v, i::Int, j::Int) D.diag[i] = v return D end diff --git a/test/Project.toml b/test/Project.toml index ceeb6e3..fdf61ca 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,7 +27,7 @@ Dictionaries = "0.4.4" FillArrays = "1.13.0" JLArrays = "0.2.0, 0.3" LinearAlgebra = "<0.0.1, 1" -NamedDimsArrays = "0.11" +NamedDimsArrays = "0.12" Random = "<0.0.1, 1" SafeTestsets = "0.1.0" SparseArrays = "1.10" From c7ab9a031e877a29fcb751e3da069c47b4d5dc00 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 5 Jan 2026 15:46:01 -0500 Subject: [PATCH 2/7] Get package loading --- Project.toml | 2 +- src/abstractsparsearray.jl | 28 +++++++++++++++++++++------- src/indexing.jl | 13 ++++++++----- src/oneelementarray.jl | 5 ++--- src/sparsearraydok.jl | 5 ++--- src/sparsearrayinterface.jl | 17 +++++++++-------- 6 files changed, 43 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index def7fdd..7ff3925 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ Adapt = "4.3" ArrayLayouts = "1.11" Dictionaries = "0.4.3" FillArrays = "1.13" -FunctionImplementations = "0.3.0" +FunctionImplementations = "0.3.1" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.5" diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index b5c1115..71eeed8 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -1,6 +1,20 @@ using Dictionaries: AbstractDictionary abstract type AbstractSparseArray{T, N} <: AbstractArray{T, N} end +const AbstractSparseVector{T} = AbstractSparseArray{T, 1} +const AbstractSparseMatrix{T} = AbstractSparseArray{T, 2} + +using Adapt: WrappedArray +const WrappedAbstractSparseArray{T, N} = + WrappedArray{T, N, AbstractSparseArray, AbstractSparseArray{T, N}} +const AnyAbstractSparseArray{T, N} = Union{ + AbstractSparseArray{T, N}, WrappedAbstractSparseArray{T, N}, +} +const AnyAbstractSparseVector{T} = AnyAbstractSparseArray{T, 1} +const AnyAbstractSparseMatrix{T} = AnyAbstractSparseArray{T, 2} +const AnyAbstractSparseVecOrMat{T} = Union{ + AnyAbstractSparseVector{T}, AnyAbstractSparseMatrix{T}, +} Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a) @@ -83,13 +97,13 @@ using LinearAlgebra: LinearAlgebra ## LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) ## end -using FunctionImplementations.Concatenate: concatenate -# We overload `Base._cat` instead of `Base.cat` since it -# is friendlier for invalidations/compile times, see -# https://github.com/ITensor/SparseArraysBase.jl/issues/25. -function Base._cat(dims, a::AnyAbstractSparseArray...) - return concatenate(dims, a...) -end +## using FunctionImplementations.Concatenate: concatenate +## # We overload `Base._cat` instead of `Base.cat` since it +## # is friendlier for invalidations/compile times, see +## # https://github.com/ITensor/SparseArraysBase.jl/issues/25. +## function Base._cat(dims, a::AnyAbstractSparseArray...) +## return concatenate(dims, a...) +## end # TODO: Use `map(WeakPreserving(f), a)` instead. # Currently that has trouble with type unstable maps, since diff --git a/src/indexing.jl b/src/indexing.jl index 03edac2..1c26c2c 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -126,7 +126,11 @@ function storedvalues end # f(::AbstractArray, I::Int) if IndexLinear # f(::AbstractArray{<:Any,N}, I::Vararg{Int,N}) if IndexCartesian -for f in (:isstored, :getunstoredindex, :getstoredindex) +const isstored_sparse = sparse_style(isstored) +const getunstoredindex_sparse = sparse_style(getunstoredindex) +const getstoredindex_sparse = sparse_style(getstoredindex) + +for f in (:isstored_sparse, :getunstoredindex_sparse, :getstoredindex_sparse) _f = Symbol(:_, f) error_if_canonical = Symbol(:error_if_canonical_, f) @eval begin @@ -171,7 +175,9 @@ for f in (:isstored, :getunstoredindex, :getstoredindex) end end -for f! in (:setunstoredindex!, :setstoredindex!) +const setunstoredindex!_sparse = sparse_style(setunstoredindex!) +const setstoredindex!_sparse = sparse_style(setstoredindex!) +for f! in (:setunstoredindex!_sparse, :setstoredindex!_sparse) _f! = Symbol(:_, f!) error_if_canonical = Symbol(:error_if_canonical_, f!) @eval begin @@ -349,7 +355,6 @@ storedvalues_sparse(A::AbstractArray) = StoredValues(A) # default implementation is a bit tricky here: we don't know if this is the "canonical" # implementation, so we check this and otherwise map back to `_isstored` to canonicalize the # indices -const isstored_sparse = sparse_style(isstored) function isstored_sparse(A::AbstractArray, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) @@ -369,7 +374,6 @@ function isstored_sparse(A::AbstractArray, I::Int...) return _isstored(style, A, Base.to_indices(A, I)...) end -const getunstoredindex_sparse = sparse_style(getunstoredindex) function getunstoredindex_sparse( A::AbstractArray, I::Int... ) @@ -392,7 +396,6 @@ function getunstoredindex_sparse( return _getunstoredindex(style, A, Base.to_indices(A, I)...) end -const getstoredindex_sparse = sparse_style(getstoredindex) function getstoredindex_sparse( A::AbstractArray, I::Int... ) diff --git a/src/oneelementarray.jl b/src/oneelementarray.jl index ea21b63..eda8443 100644 --- a/src/oneelementarray.jl +++ b/src/oneelementarray.jl @@ -16,9 +16,8 @@ struct OneElementArray{T, N, I, Unstored <: AbstractArray{T, N}} <: AbstractSpar end end -## using DerivableInterfaces: @array_aliases -## # Define `OneElementMatrix`, `AnyOneElementArray`, etc. -## @array_aliases OneElementArray +const OneElementVector{T} = OneElementArray{T, 1} +const OneElementMatrix{T} = OneElementArray{T, 2} function OneElementArray{T, N}( value, index::NTuple{N, Int}, axes::NTuple{N, AbstractUnitRange} diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 6d61ec3..f8132ca 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -71,9 +71,8 @@ end ## return similar(SparseArrayDOK{T}, ax) ## end -## using DerivableInterfaces: @array_aliases -## # Define `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. -## @array_aliases SparseArrayDOK +const SparseVectorDOK{T} = SparseArrayDOK{T, 1} +const SparseMatrixDOK{T} = SparseArrayDOK{T, 2} storage(a::SparseArrayDOK) = a.storage diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index 544d697..9b097d6 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -93,7 +93,7 @@ end module Broadcast - abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end + abstract type AbstractSparseArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end ## @derive (T = AbstractSparseArrayStyle,) begin ## Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple) @@ -105,14 +105,15 @@ module Broadcast SparseArrayStyle{M}(::Val{N}) where {M, N} = SparseArrayStyle{N}() # TODO: Don't make this a `sparse_style` function. - BroadcastStyle_sparse = sparse_style(Broadcast.BroadcastStyle) - function BroadcastStyle_sparse(type::Type) - return SparseArrayStyle{ndims(type)}() - end + ## using ..SparseArraysBase: sparse_style + ## const BroadcastStyle_sparse = sparse_style(Base.Broadcast.BroadcastStyle) + ## function Base.Broadcast.BroadcastStyle(type::Type{<:AnyAbstractSparseArray}) + ## return SparseArrayStyle{ndims(type)}() + ## end end # module Broadcast -# TODO: Don't make this a `sparse_style` function. +## # TODO: Don't make this a `sparse_style` function. struct SparseLayout <: AbstractSparseLayout end -const MemoryLayout_sparse = sparse_style(ArrayLayouts.MemoryLayout) -MemoryLayout_sparse(type::Type) = SparseLayout() +## const MemoryLayout_sparse = sparse_style(ArrayLayouts.MemoryLayout) +## ArrayLayouts.MemoryLayout(type::Type{<:AnyAbstractSparseArray}) = SparseLayout() From 8194ebbbb87735e4274d7068b5d0f0f1046f6e52 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 5 Jan 2026 15:55:04 -0500 Subject: [PATCH 3/7] Fix some tests --- src/indexing.jl | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 1c26c2c..97d05d1 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -1,5 +1,5 @@ using Base: @_propagate_inbounds_meta -using FunctionImplementations: style +using FunctionImplementations: Implementation, style # Indexing interface # ------------------ @@ -126,15 +126,11 @@ function storedvalues end # f(::AbstractArray, I::Int) if IndexLinear # f(::AbstractArray{<:Any,N}, I::Vararg{Int,N}) if IndexCartesian -const isstored_sparse = sparse_style(isstored) -const getunstoredindex_sparse = sparse_style(getunstoredindex) -const getstoredindex_sparse = sparse_style(getstoredindex) - -for f in (:isstored_sparse, :getunstoredindex_sparse, :getstoredindex_sparse) +for f in (:isstored, :getunstoredindex, :getstoredindex) _f = Symbol(:_, f) error_if_canonical = Symbol(:error_if_canonical_, f) @eval begin - function $f(A::AbstractArray, I...) + function (::Implementation{typeof($f)})(A::AbstractArray, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -175,13 +171,11 @@ for f in (:isstored_sparse, :getunstoredindex_sparse, :getstoredindex_sparse) end end -const setunstoredindex!_sparse = sparse_style(setunstoredindex!) -const setstoredindex!_sparse = sparse_style(setstoredindex!) -for f! in (:setunstoredindex!_sparse, :setstoredindex!_sparse) +for f! in (:setstoredindex!, :setunstoredindex!) _f! = Symbol(:_, f!) error_if_canonical = Symbol(:error_if_canonical_, f!) @eval begin - function $f!(A::AbstractArray, v, I...) + function (::Implementation{typeof($f!)})(A::AbstractArray, v, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -355,6 +349,7 @@ storedvalues_sparse(A::AbstractArray) = StoredValues(A) # default implementation is a bit tricky here: we don't know if this is the "canonical" # implementation, so we check this and otherwise map back to `_isstored` to canonicalize the # indices +const isstored_sparse = sparse_style(isstored) function isstored_sparse(A::AbstractArray, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) @@ -374,6 +369,7 @@ function isstored_sparse(A::AbstractArray, I::Int...) return _isstored(style, A, Base.to_indices(A, I)...) end +const getunstoredindex_sparse = sparse_style(getunstoredindex) function getunstoredindex_sparse( A::AbstractArray, I::Int... ) @@ -396,6 +392,7 @@ function getunstoredindex_sparse( return _getunstoredindex(style, A, Base.to_indices(A, I)...) end +const getstoredindex_sparse = sparse_style(getstoredindex) function getstoredindex_sparse( A::AbstractArray, I::Int... ) @@ -405,13 +402,11 @@ function getstoredindex_sparse( return _getstoredindex(style, A, Base.to_indices(A, I)...) end -const setstoredindex!_sparse = sparse_style(setstoredindex!) -const setunstoredindex!_sparse = sparse_style(setunstoredindex!) -for f! in (:setstoredindex!_sparse, :setunstoredindex!_sparse) +for f! in (:setstoredindex!, :setunstoredindex!) _f! = Symbol(:_, f!) error_if_canonical_setstoredindex = Symbol(:error_if_canonical_, f!) @eval begin - function $f!(A::AbstractArray, v, I::Int...) + function (::Implementation{typeof($f!)})(A::AbstractArray, v, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical_setstoredindex(style, A, I...) From 958b10d7eeeeb88d0b5f2965e7b3163077fca1ec Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 5 Jan 2026 17:52:15 -0500 Subject: [PATCH 4/7] Fix more tests --- src/abstractsparsearray.jl | 62 ++++++++++++++++-------------- src/indexing.jl | 76 +++++++++++++++++++++---------------- src/sparsearrayinterface.jl | 5 +++ src/wrappers.jl | 24 +++++++----- 4 files changed, 96 insertions(+), 71 deletions(-) diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index 71eeed8..8965d84 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -23,9 +23,7 @@ Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T ## @array_aliases AbstractSparseArray using FunctionImplementations: FunctionImplementations -function FunctionImplementations.Style(::Type{<:AbstractSparseArray}) - return SparseArrayInterface() -end +FunctionImplementations.Style(::Type{<:AnyAbstractSparseArray}) = SparseArrayStyle() function Base.copy(a::AnyAbstractSparseArray) return copyto!(similar(a), a) @@ -63,8 +61,6 @@ function Base.similar( return similar_sparsearray(a, T, ax) end -## using DerivableInterfaces: @derive - # TODO: These need to be loaded since `AbstractArrayOps` # includes overloads of functions from these modules. # Ideally that wouldn't be needed and can be circumvented @@ -72,30 +68,38 @@ end using ArrayLayouts: ArrayLayouts using LinearAlgebra: LinearAlgebra -## @derive (T = AnyAbstractSparseArray,) begin -## Base.getindex(::T, ::Any...) -## Base.getindex(::T, ::Int...) -## Base.setindex!(::T, ::Any, ::Any...) -## Base.setindex!(::T, ::Any, ::Int...) -## Base.copy!(::AbstractArray, ::T) -## Base.copyto!(::AbstractArray, ::T) -## Base.map(::Any, ::T...) -## Base.map!(::Any, ::AbstractArray, ::T...) -## Base.mapreduce(::Any, ::Any, ::T...; kwargs...) -## Base.reduce(::Any, ::T...; kwargs...) -## Base.all(::Function, ::T) -## Base.all(::T) -## Base.iszero(::T) -## Base.real(::T) -## Base.fill!(::T, ::Any) -## DerivableInterfaces.zero!(::T) -## Base.zero(::T) -## Base.permutedims!(::Any, ::T, ::Any) -## Broadcast.BroadcastStyle(::Type{<:T}) -## Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) -## ArrayLayouts.MemoryLayout(::Type{<:T}) -## LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) -## end +Base.getindex(a::AnyAbstractSparseArray, I::Any...) = style(a)(getindex)(a, I...) +Base.getindex(a::AnyAbstractSparseArray, I::Int...) = style(a)(getindex)(a, I...) +Base.setindex!(a::AnyAbstractSparseArray, x, I::Any...) = style(a)(setindex!)(a, x, I...) +Base.setindex!(a::AnyAbstractSparseArray, x, I::Int...) = style(a)(setindex!)(a, x, I...) +Base.copy!(dst::AbstractArray, src::AnyAbstractSparseArray) = style(src)(copy!)(dst, src) +Base.copyto!(dst::AbstractArray, src::AnyAbstractSparseArray) = style(src)(copyto!)(dst, src) +Base.map(f, as::AnyAbstractSparseArray...) = style(as...)(map)(f, as...) +function Base.map!(f, dst::AbstractArray, as::AnyAbstractSparseArray...) + return style(as...)(map!)(f, dst, as...) +end +function Base.mapreduce(f, op, as::AnyAbstractSparseArray...; kwargs...) + return style(as...)(mapreduce)(f, op, as...; kwargs...) +end +function Base.reduce(f, as::AnyAbstractSparseArray...; kwargs...) + return style(as...)(reduce)(f, as...; kwargs...) +end +Base.all(f::Function, a::AnyAbstractSparseArray) = style(a)(all)(f, a) +Base.all(a::AnyAbstractSparseArray) = style(a)(all)(a) +Base.iszero(a::AnyAbstractSparseArray) = style(a)(iszero)(a) +Base.isreal(a::AnyAbstractSparseArray) = style(a)(isreal)(a) +Base.real(a::AnyAbstractSparseArray) = style(a)(real)(a) +Base.fill!(a::AnyAbstractSparseArray, x) = style(a)(fill!)(a, x) +FunctionImplementations.zero!(a::AnyAbstractSparseArray) = style(a)(zero!)(a) +Base.zero(a::AnyAbstractSparseArray) = style(a)(zero)(a) +function Base.permutedims!(dst, a::AnyAbstractSparseArray, perm) + return style(a)(permutedims!)(dst, a, perm) +end + +## Broadcast.BroadcastStyle(::Type{<:T}) +## Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) +## ArrayLayouts.MemoryLayout(::Type{<:T}) +## LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) ## using FunctionImplementations.Concatenate: concatenate ## # We overload `Base._cat` instead of `Base.cat` since it diff --git a/src/indexing.jl b/src/indexing.jl index 97d05d1..2efef29 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -3,8 +3,6 @@ using FunctionImplementations: Implementation, style # Indexing interface # ------------------ -# these definitions are not using @derive since we need the @inline annotation -# to correctly deal with boundschecks and @inbounds """ getstoredindex(A::AbstractArray, I...) -> eltype(A) @@ -112,13 +110,13 @@ to be the same as [`eachstoredindex`](@ref). """ function storedvalues end -## @derive (T = AbstractArray,) begin -## SparseArraysBase.eachstoredindex(::T...) -## SparseArraysBase.eachstoredindex(::IndexStyle, ::T...) -## SparseArraysBase.storedlength(::T) -## SparseArraysBase.storedpairs(::T) -## SparseArraysBase.storedvalues(::T) -## end +eachstoredindex(as::AbstractArray...) = style(as...)(eachstoredindex)(as...) +function eachstoredindex(indexstyle::IndexStyle, as::AbstractArray...) + return style(as...)(eachstoredindex)(indexstyle, as...) +end +storedlength(a::AbstractArray) = style(a)(storedlength)(a) +storedpairs(a::AbstractArray) = style(a)(storedpairs)(a) +storedvalues(a::AbstractArray) = style(a)(storedvalues)(a) # canonical indexing # ------------------ @@ -216,42 +214,42 @@ for f! in (:setstoredindex!, :setunstoredindex!) end end -# AbstractArrayInterface fallback definitions +# AbstractArrayStyle fallback definitions # ------------------------------------------- -function isstored(A::AbstractArray, i::Int, I::Int...) +function (::Implementation{typeof(isstored)})(A::AbstractArray, i::Int, I::Int...) @inline @boundscheck checkbounds(A, i, I...) return true end -function getunstoredindex(A::AbstractArray, I::Int...) +function (::Implementation{typeof(getunstoredindex)})(A::AbstractArray, I::Int...) @inline @boundscheck checkbounds(A, I...) return zero(eltype(A)) end -function getstoredindex(A::AbstractArray, I::Int...) +function (::Implementation{typeof(getstoredindex)})(A::AbstractArray, I::Int...) @inline return getindex(A, I...) end -function setstoredindex!(A::AbstractArray, v, I::Int...) +function (::Implementation{typeof(setstoredindex!)})(A::AbstractArray, v, I::Int...) @inline return setindex!(A, v, I...) end -setunstoredindex!(A::AbstractArray, v, I::Int...) = error( - "setunstoredindex! for $(typeof(A)) is not supported" -) +function (::Implementation{typeof(setunstoredindex!)})(A::AbstractArray, v, I::Int...) + return error("setunstoredindex! for $(typeof(A)) is not supported") +end -eachstoredindex(A::AbstractArray, B::AbstractArray...) = eachstoredindex( - IndexStyle(A, B...), A, B... -) -eachstoredindex(style::IndexStyle, A::AbstractArray, B::AbstractArray...) = eachindex( - style, A, B... -) +function (::Implementation{typeof(eachstoredindex)})(A::AbstractArray, B::AbstractArray...) + return eachstoredindex(IndexStyle(A, B...), A, B...) +end +function (::Implementation{typeof(eachstoredindex)})(style::IndexStyle, A::AbstractArray, B::AbstractArray...) + return eachindex(style, A, B...) +end -storedvalues(A::AbstractArray) = values(A) -storedpairs(A::AbstractArray) = pairs(A) -storedlength(A::AbstractArray) = length(storedvalues(A)) +(::Implementation{typeof(storedvalues)})(a::AbstractArray) = values(a) +(::Implementation{typeof(storedpairs)})(a::AbstractArray) = pairs(a) +(::Implementation{typeof(storedlength)})(a::AbstractArray) = length(storedvalues(a)) # SparseArrayInterface implementations # ------------------------------------ @@ -276,6 +274,12 @@ function getindex_sparse(A::AbstractVector, I::Int) @boundscheck checkbounds(A, I) return @inbounds isstored(A, I) ? getstoredindex(A, I) : getunstoredindex(A, I) end +# TODO: Make this more general, use `Base.to_index`. +function getindex_sparse( + a::AbstractArray{<:Any, N}, I::CartesianIndex{N} + ) where {N} + return getindex_sparse(a, Tuple(I)...) +end const setindex!_sparse = sparse_style(setindex!) function setindex!_sparse( @@ -312,6 +316,16 @@ function setindex!_sparse( setunstoredindex!(A, v, I) end end +# TODO: Make this more general, use `Base.to_index`. +function setindex!_sparse( + a::AbstractArray{<:Any, N}, value, I::CartesianIndex{N} + ) where {N} + return setindex!(a, value, Tuple(I)...) +end +function setindex!_sparse(a::AbstractArray, value, I...) + map!(identity, @view(a[I...]), value) + return a +end @noinline function error_if_canonical_eachstoredindex(style::IndexStyle, A::AbstractArray) style === IndexStyle(A) && throw(Base.CanonicalIndexError("eachstoredindex", typeof(A))) @@ -402,11 +416,13 @@ function getstoredindex_sparse( return _getstoredindex(style, A, Base.to_indices(A, I)...) end -for f! in (:setstoredindex!, :setunstoredindex!) +const setstoredindex!_sparse = sparse_style(setstoredindex!) +const setunstoredindex!_sparse = sparse_style(setunstoredindex!) +for f! in (:setstoredindex!_sparse, :setunstoredindex!_sparse) _f! = Symbol(:_, f!) error_if_canonical_setstoredindex = Symbol(:error_if_canonical_, f!) @eval begin - function (::Implementation{typeof($f!)})(A::AbstractArray, v, I::Int...) + function $f!(A::AbstractArray, v, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical_setstoredindex(style, A, I...) @@ -415,10 +431,6 @@ for f! in (:setstoredindex!, :setunstoredindex!) end end -const storedlength_sparse = sparse_style(storedlength) -storedlength_sparse(A::AbstractArray) = length( - storedvalues(A) -) const storedpairs_sparse = sparse_style(storedpairs) function storedpairs_sparse(A::AbstractArray) return Iterators.map(I -> (I => A[I]), eachstoredindex(A)) diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index 9b097d6..607934f 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -49,6 +49,11 @@ const sparse_style = SparseArrayStyle() ## return a ## end +const fill!_sparse = sparse_style(fill!) +function fill!_sparse(a::AbstractArray, value) + return map!(Returns(value), a, a) +end + using FunctionImplementations: FunctionImplementations, zero! # `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` diff --git a/src/wrappers.jl b/src/wrappers.jl index 47df938..fb3c430 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -190,8 +190,11 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose end end +using FunctionImplementations: Style using LinearAlgebra: LinearAlgebra, Diagonal -storedvalues_sparse(D::Diagonal) = LinearAlgebra.diag(D) +const diag_style = Style(Diagonal) +const storedvalues_diag = diag_style(storedvalues) +storedvalues_diag(D::Diagonal) = LinearAlgebra.diag(D) # compat with LTS: @static if VERSION ≥ v"1.11" @@ -201,20 +204,21 @@ else return view(CartesianIndices(x), LinearAlgebra.diagind(x)) end end -eachstoredindex_sparse(D::Diagonal) = _diagind( - D, IndexCartesian() -) +const eachstoredindex_diag = diag_style(eachstoredindex) +eachstoredindex_diag(D::Diagonal) = _diagind(D, IndexCartesian()) -function isstored_sparse(D::Diagonal, i::Int, j::Int) +const isstored_diag = diag_style(isstored) +function isstored_diag(D::Diagonal, i::Int, j::Int) return i == j && checkbounds(Bool, D, i, j) end -function getstoredindex_sparse(D::Diagonal, i::Int, j::Int) - return D.diag[i] -end -function getunstoredindex_sparse(D::Diagonal, i::Int, j::Int) +const getstoredindex_diag = diag_style(getstoredindex) +getstoredindex_diag(D::Diagonal, i::Int, j::Int) = D.diag[i] +const getunstoredindex_diag = diag_style(getunstoredindex) +function getunstoredindex_diag(D::Diagonal, i::Int, j::Int) return zero(eltype(D)) end -function setstoredindex!_sparse(D::Diagonal, v, i::Int, j::Int) +const setstoredindex!_diag = diag_style(setstoredindex!) +function setstoredindex!_diag(D::Diagonal, v, i::Int, j::Int) D.diag[i] = v return D end From 1aa2b942aa39bce924b525ce57b072804f1c8201 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 5 Jan 2026 21:30:30 -0500 Subject: [PATCH 5/7] Get tests passing --- src/abstractsparsearray.jl | 13 ++++++ src/abstractsparsearrayinterface.jl | 10 ++-- src/indexing.jl | 9 +++- src/map.jl | 27 +++++++++++ src/sparsearrayinterface.jl | 72 +++++++++++++++++++++++------ 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index 8965d84..da8658d 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -95,6 +95,19 @@ Base.zero(a::AnyAbstractSparseArray) = style(a)(zero)(a) function Base.permutedims!(dst, a::AnyAbstractSparseArray, perm) return style(a)(permutedims!)(dst, a, perm) end +function LinearAlgebra.mul!( + dst::AbstractMatrix, a1::AnyAbstractSparseArray, a2::AnyAbstractSparseArray, + α::Number, β::Number, + ) + return style(a1, a2)(mul!)(dst, a1, a2, α, β) +end + +function Base.Broadcast.BroadcastStyle(type::Type{<:AnyAbstractSparseArray}) + return Broadcast.SparseArrayStyle{ndims(type)}() +end + +using ArrayLayouts: ArrayLayouts +ArrayLayouts.MemoryLayout(type::Type{<:AnyAbstractSparseArray}) = SparseLayout() ## Broadcast.BroadcastStyle(::Type{<:T}) ## Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index 36e7f7f..0ba49a3 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -120,7 +120,7 @@ function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2}) end using LinearAlgebra: mul! -function default_mul!!( +function mul!!( a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, @@ -131,20 +131,20 @@ function default_mul!!( return a_dest end -function default_mul!!( +function mul!!( a_dest::Number, a1::Number, a2::Number, α::Number = true, β::Number = false ) return a1 * a2 * α + a_dest * β end # a1 * a2 * α + a_dest * β -function sparse_mul!( +function _mul!_sparse( a_dest::AbstractArray, a1::AbstractArray, a2::AbstractArray, α::Number = true, β::Number = false; - (mul!!) = (default_mul!!), + (mul!!) = (mul!!), ) a_dest .*= β β′ = one(Bool) @@ -166,6 +166,6 @@ end function ArrayLayouts.materialize!( m::MatMulMatAdd{<:AbstractSparseLayout, <:AbstractSparseLayout, <:AbstractSparseLayout} ) - sparse_mul!(m.C, m.A, m.B, m.α, m.β) + _mul!_sparse(m.C, m.A, m.B, m.α, m.β) return m.C end diff --git a/src/indexing.jl b/src/indexing.jl index 2efef29..2805675 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -280,6 +280,10 @@ function getindex_sparse( ) where {N} return getindex_sparse(a, Tuple(I)...) end +using ArrayLayouts: ArrayLayouts +function getindex_sparse(a::AbstractArray, I...) + return ArrayLayouts.layout_getindex(a, I...) +end const setindex!_sparse = sparse_style(setindex!) function setindex!_sparse( @@ -418,11 +422,12 @@ end const setstoredindex!_sparse = sparse_style(setstoredindex!) const setunstoredindex!_sparse = sparse_style(setunstoredindex!) -for f! in (:setstoredindex!_sparse, :setunstoredindex!_sparse) +for f! in (:setstoredindex!, :setunstoredindex!) + f!_sparse = Symbol(f!, :_sparse) _f! = Symbol(:_, f!) error_if_canonical_setstoredindex = Symbol(:error_if_canonical_, f!) @eval begin - function $f!(A::AbstractArray, v, I::Int...) + function $f!_sparse(A::AbstractArray, v, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical_setstoredindex(style, A, I...) diff --git a/src/map.jl b/src/map.jl index 1ab5a52..d44e337 100644 --- a/src/map.jl +++ b/src/map.jl @@ -131,6 +131,13 @@ function copyto!_sparse( return dest end +const permutedims!_sparse = sparse_style(permutedims!) +function permutedims!_sparse( + a_dest::AbstractArray, a_src::AbstractArray, perm + ) + return map!(identity, a_dest, PermutedDimsArray(a_src, perm)) +end + # Only map the stored values of the inputs. function map_stored! end @@ -153,6 +160,26 @@ function map_all!_sparse( return a_dest end +# TODO: Generalize to multiple inputs. +const reduce_sparse = sparse_style(reduce) +function reduce_sparse(f, a::AbstractArray; kwargs...) + return mapreduce(identity, f, a; kwargs...) +end + +const all_sparse = sparse_style(all) +function all_sparse(a::AbstractArray) + return reduce(&, a; init = true) +end +function all_sparse(f::Function, a::AbstractArray) + return mapreduce(f, &, a; init = true) +end + +const isreal_sparse = sparse_style(isreal) +isreal_sparse(a::AbstractArray) = all(isreal, a) + +const iszero_sparse = sparse_style(iszero) +iszero_sparse(a::AbstractArray) = all(iszero, a) + # Utility functions # ----------------- # shape check similar to checkbounds diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index 607934f..a4b31af 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -71,6 +71,20 @@ function zero!_sparse(a::AbstractArray) return a end +const zero_sparse = sparse_style(zero) +# Specialized version of `Base.zero` written in terms of `zero!`. +# This is friendlier for sparse arrays since `zero!` makes it easier +# to handle the logic of dropping all elements of the sparse array when possible. +# We use a single function definition to minimize method ambiguities. +function zero_sparse(a::AbstractArray) + # More generally, the first codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + if eltype(a) <: Number + return zero!(similar(a)) + end + return map(zero, a) +end + # `f::typeof(norm)`, `op::typeof(max)` used by `norm`. function reduce_init(f, op, as...) # TODO: Generalize this. @@ -82,6 +96,15 @@ function reduce_init(f, op, as...) return f(getunstoredindex(a, first(eachindex(a)))) end +# This is defined in this way so we can rely on the Broadcast logic +# for determining the destination of the operation (element type, shape, etc.). +const map_sparse = sparse_style(map) +function map_sparse(f, as::AbstractArray...) + # Broadcasting is used here to determine the destination array but that + # could be done manually here. + return f.(as...) +end + const mapreduce_sparse = sparse_style(mapreduce) function mapreduce_sparse( f, op, as::AbstractArray...; init = reduce_init(f, op, as...), kwargs... @@ -97,26 +120,45 @@ function mapreduce_sparse( end module Broadcast - abstract type AbstractSparseArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end - - ## @derive (T = AbstractSparseArrayStyle,) begin - ## Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple) - ## Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:T}) - ## end - struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end - SparseArrayStyle{M}(::Val{N}) where {M, N} = SparseArrayStyle{N}() +end + +using MapBroadcast: Mapped +# TODO: Look into `SparseArrays.capturescalars`: +# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 +function Base.copyto!( + a_dest::AbstractArray, bc::Base.Broadcast.Broadcasted{<:Broadcast.SparseArrayStyle} + ) + m = Mapped(bc) + map!(m.f, a_dest, m.args...) + return a_dest +end + +## # This captures broadcast expressions such as `a .= 2`. +## # Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet: +## # https://github.com/JuliaLang/julia/issues/31677 +## # https://github.com/JuliaLang/julia/pull/40632 +## function Base.copyto!( +## a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} +## ) +## @interface interface fill!(a_dest, bc.f(bc.args...)[]) +## end - # TODO: Don't make this a `sparse_style` function. - ## using ..SparseArraysBase: sparse_style - ## const BroadcastStyle_sparse = sparse_style(Base.Broadcast.BroadcastStyle) - ## function Base.Broadcast.BroadcastStyle(type::Type{<:AnyAbstractSparseArray}) - ## return SparseArrayStyle{ndims(type)}() - ## end +function Base.similar( + bc::Base.Broadcast.Broadcasted{<:Broadcast.SparseArrayStyle}, elt::Type, ax + ) + return similar(SparseArrayDOK{elt}, ax) +end -end # module Broadcast +using ArrayLayouts: ArrayLayouts +const mul!_sparse = sparse_style(mul!) +function mul!_sparse( + a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number + ) + return ArrayLayouts.mul!(a_dest, a1, a2, α, β) +end ## # TODO: Don't make this a `sparse_style` function. struct SparseLayout <: AbstractSparseLayout end From a78ccc11778102d5a7ea0a1bd2ccd7d74d948bed Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 6 Jan 2026 15:52:01 -0500 Subject: [PATCH 6/7] Cleanup --- src/SparseArraysBase.jl | 4 +- src/abstractsparsearray.jl | 33 +++-------- ...terface.jl => abstractsparsearraystyle.jl} | 10 ---- src/sparsearraydok.jl | 6 -- ...earrayinterface.jl => sparsearraystyle.jl} | 58 +------------------ src/wrappers.jl | 14 ++--- 6 files changed, 20 insertions(+), 105 deletions(-) rename src/{abstractsparsearrayinterface.jl => abstractsparsearraystyle.jl} (94%) rename src/{sparsearrayinterface.jl => sparsearraystyle.jl} (64%) diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index 9f75a13..0f5ce2f 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -17,8 +17,8 @@ export SparseArrayDOK, storedpairs, storedvalues -include("abstractsparsearrayinterface.jl") -include("sparsearrayinterface.jl") +include("abstractsparsearraystyle.jl") +include("sparsearraystyle.jl") include("indexing.jl") include("map.jl") include("wrappers.jl") diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index da8658d..c835b48 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -18,10 +18,6 @@ const AnyAbstractSparseVecOrMat{T} = Union{ Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a) -## using DerivableInterfaces: @array_aliases -## # Define AbstractSparseVector, AnyAbstractSparseArray, etc. -## @array_aliases AbstractSparseArray - using FunctionImplementations: FunctionImplementations FunctionImplementations.Style(::Type{<:AnyAbstractSparseArray}) = SparseArrayStyle() @@ -61,10 +57,6 @@ function Base.similar( return similar_sparsearray(a, T, ax) end -# TODO: These need to be loaded since `AbstractArrayOps` -# includes overloads of functions from these modules. -# Ideally that wouldn't be needed and can be circumvented -# with `GlobalRef`. using ArrayLayouts: ArrayLayouts using LinearAlgebra: LinearAlgebra @@ -73,7 +65,9 @@ Base.getindex(a::AnyAbstractSparseArray, I::Int...) = style(a)(getindex)(a, I... Base.setindex!(a::AnyAbstractSparseArray, x, I::Any...) = style(a)(setindex!)(a, x, I...) Base.setindex!(a::AnyAbstractSparseArray, x, I::Int...) = style(a)(setindex!)(a, x, I...) Base.copy!(dst::AbstractArray, src::AnyAbstractSparseArray) = style(src)(copy!)(dst, src) -Base.copyto!(dst::AbstractArray, src::AnyAbstractSparseArray) = style(src)(copyto!)(dst, src) +function Base.copyto!(dst::AbstractArray, src::AnyAbstractSparseArray) + return style(src)(copyto!)(dst, src) +end Base.map(f, as::AnyAbstractSparseArray...) = style(as...)(map)(f, as...) function Base.map!(f, dst::AbstractArray, as::AnyAbstractSparseArray...) return style(as...)(map!)(f, dst, as...) @@ -109,18 +103,11 @@ end using ArrayLayouts: ArrayLayouts ArrayLayouts.MemoryLayout(type::Type{<:AnyAbstractSparseArray}) = SparseLayout() -## Broadcast.BroadcastStyle(::Type{<:T}) -## Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) -## ArrayLayouts.MemoryLayout(::Type{<:T}) -## LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number) - -## using FunctionImplementations.Concatenate: concatenate -## # We overload `Base._cat` instead of `Base.cat` since it -## # is friendlier for invalidations/compile times, see -## # https://github.com/ITensor/SparseArraysBase.jl/issues/25. -## function Base._cat(dims, a::AnyAbstractSparseArray...) -## return concatenate(dims, a...) -## end +using FunctionImplementations.Concatenate: concatenate +# We overload `Base._cat` instead of `Base.cat` since it +# is friendlier for invalidations/compile times, see: +# https://github.com/ITensor/SparseArraysBase.jl/issues/25 +Base._cat(dims, a::AnyAbstractSparseArray...) = concatenate(dims, a...) # TODO: Use `map(WeakPreserving(f), a)` instead. # Currently that has trouble with type unstable maps, since @@ -278,10 +265,6 @@ function sparserand!( end end -# Catch some cases that aren't getting caught by the current -# DerivableInterfaces.jl logic. -# TODO: Make this more systematic once DerivableInterfaces.jl -# is rewritten. using ArrayLayouts: ArrayLayouts, MemoryLayout using LinearAlgebra: LinearAlgebra, Adjoint function ArrayLayouts.MemoryLayout(::Type{Transpose{T, P}}) where {T, P <: AbstractSparseMatrix} diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearraystyle.jl similarity index 94% rename from src/abstractsparsearrayinterface.jl rename to src/abstractsparsearraystyle.jl index 0ba49a3..45bd02c 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearraystyle.jl @@ -62,21 +62,11 @@ function FunctionImplementations.Style( ) return error("Rule not defined.") end -## function FunctionImplementations.Style( -## style1::Style, style2::Style -## ) where {Style <: AbstractSparseArrayStyle} -## return style1 -## end function FunctionImplementations.Style( style1::AbstractSparseArrayStyle, style2::AbstractArrayStyle ) return style1 end -## function FunctionImplementations.Style( -## style1::AbstractArrayStyle, style2::AbstractSparseArrayStyle -## ) -## return style2 -## end to_vec(x) = vec(collect(x)) to_vec(x::AbstractArray) = vec(x) diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index f8132ca..44217d3 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -65,12 +65,6 @@ function SparseArrayDOK{T}(::UndefInitializer, ax::Vararg{Any, N}) where {T, N} return SparseArrayDOK{T, N}(undef, ax) end -## using DerivableInterfaces: DerivableInterfaces -## # This defines the destination type of various operations in DerivableInterfaces.jl. -## function Base.similar(::AbstractSparseArrayInterface, T::Type, ax::Tuple) -## return similar(SparseArrayDOK{T}, ax) -## end - const SparseVectorDOK{T} = SparseArrayDOK{T, 1} const SparseMatrixDOK{T} = SparseArrayDOK{T, 2} diff --git a/src/sparsearrayinterface.jl b/src/sparsearraystyle.jl similarity index 64% rename from src/sparsearrayinterface.jl rename to src/sparsearraystyle.jl index a4b31af..6c5ccb0 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearraystyle.jl @@ -2,53 +2,12 @@ using FunctionImplementations: FunctionImplementations struct SparseArrayStyle <: AbstractSparseArrayStyle end -# Convenient shorthand to refer to the sparse interface. +# Convenient shorthand to refer to the sparse style. # Can turn a function into a sparse function with the syntax `sparse_style(f)`, # i.e. `sparse_style(map)(x -> 2x, randn(2, 2))` while use the sparse # version of `map`. const sparse_style = SparseArrayStyle() -## # Fix ambiguity error. -## function DerivableInterfaces.combine_interface_rule( -## ::SparseArrayInterface{N}, ::SparseArrayInterface{N} -## ) where {N} -## return SparseArrayInterface{N}() -## end -## function DerivableInterfaces.combine_interface_rule( -## ::SparseArrayInterface, ::SparseArrayInterface -## ) -## return SparseArrayInterface() -## end -## function DerivableInterfaces.combine_interface_rule( -## interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface -## ) -## return interface1 -## end -## function DerivableInterfaces.combine_interface_rule( -## interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface -## ) -## return interface2 -## end - -## FunctionImplementations.Style(::Type{<:AbstractSparseArrayStyle}) = SparseArrayStyle() - -## using FunctionImplementations: zero! -## -## # `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` -## # and is useful for sparse array logic, since it can be used to empty -## # the sparse array storage. -## # We use a single function definition to minimize method ambiguities. -## const zero!_sparse = sparse_style(zero!) -## function zero!_sparse(a::AbstractArray) -## # More generally, this codepath could be taking if `zero(eltype(a))` -## # is defined and the elements are immutable. -## f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! -## @inbounds for I in eachstoredindex(a) -## a[I] = f(a[I]) -## end -## return a -## end - const fill!_sparse = sparse_style(fill!) function fill!_sparse(a::AbstractArray, value) return map!(Returns(value), a, a) @@ -119,6 +78,8 @@ function mapreduce_sparse( return output end +# Namespace for Broadcast styles to avoid clashing with FunctionImplementations +# styles. module Broadcast abstract type AbstractSparseArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end @@ -136,16 +97,6 @@ function Base.copyto!( return a_dest end -## # This captures broadcast expressions such as `a .= 2`. -## # Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet: -## # https://github.com/JuliaLang/julia/issues/31677 -## # https://github.com/JuliaLang/julia/pull/40632 -## function Base.copyto!( -## a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} -## ) -## @interface interface fill!(a_dest, bc.f(bc.args...)[]) -## end - function Base.similar( bc::Base.Broadcast.Broadcasted{<:Broadcast.SparseArrayStyle}, elt::Type, ax ) @@ -160,7 +111,4 @@ function mul!_sparse( return ArrayLayouts.mul!(a_dest, a1, a2, α, β) end -## # TODO: Don't make this a `sparse_style` function. struct SparseLayout <: AbstractSparseLayout end -## const MemoryLayout_sparse = sparse_style(ArrayLayouts.MemoryLayout) -## ArrayLayouts.MemoryLayout(type::Type{<:AnyAbstractSparseArray}) = SparseLayout() diff --git a/src/wrappers.jl b/src/wrappers.jl index fb3c430..3ce15a8 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -194,31 +194,31 @@ using FunctionImplementations: Style using LinearAlgebra: LinearAlgebra, Diagonal const diag_style = Style(Diagonal) const storedvalues_diag = diag_style(storedvalues) -storedvalues_diag(D::Diagonal) = LinearAlgebra.diag(D) +storedvalues_diag(D::AbstractMatrix) = LinearAlgebra.diag(D) # compat with LTS: @static if VERSION ≥ v"1.11" _diagind = LinearAlgebra.diagind else - function _diagind(x::Diagonal, ::IndexCartesian) + function _diagind(x::AbstractMatrix, ::IndexCartesian) return view(CartesianIndices(x), LinearAlgebra.diagind(x)) end end const eachstoredindex_diag = diag_style(eachstoredindex) -eachstoredindex_diag(D::Diagonal) = _diagind(D, IndexCartesian()) +eachstoredindex_diag(D::AbstractMatrix) = _diagind(D, IndexCartesian()) const isstored_diag = diag_style(isstored) -function isstored_diag(D::Diagonal, i::Int, j::Int) +function isstored_diag(D::AbstractMatrix, i::Int, j::Int) return i == j && checkbounds(Bool, D, i, j) end const getstoredindex_diag = diag_style(getstoredindex) -getstoredindex_diag(D::Diagonal, i::Int, j::Int) = D.diag[i] +getstoredindex_diag(D::AbstractMatrix, i::Int, j::Int) = D.diag[i] const getunstoredindex_diag = diag_style(getunstoredindex) -function getunstoredindex_diag(D::Diagonal, i::Int, j::Int) +function getunstoredindex_diag(D::AbstractMatrix, i::Int, j::Int) return zero(eltype(D)) end const setstoredindex!_diag = diag_style(setstoredindex!) -function setstoredindex!_diag(D::Diagonal, v, i::Int, j::Int) +function setstoredindex!_diag(D::AbstractMatrix, v, i::Int, j::Int) D.diag[i] = v return D end From d58327775d222cb76c4dac2bd4a1fcbb0e0d03ff Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 6 Jan 2026 16:14:58 -0500 Subject: [PATCH 7/7] Mark as breaking --- Project.toml | 2 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 7ff3925..c102112 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.7.12" +version = "0.8.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/docs/Project.toml b/docs/Project.toml index 786b536..92bdbdf 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,4 +11,4 @@ SparseArraysBase = {path = ".."} Dictionaries = "0.4.4" Documenter = "1.8.1" Literate = "2.20.1" -SparseArraysBase = "0.7.0" +SparseArraysBase = "0.8" diff --git a/examples/Project.toml b/examples/Project.toml index 9dbdee0..e11cac3 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -8,5 +8,5 @@ SparseArraysBase = {path = ".."} [compat] Dictionaries = "0.4.4" -SparseArraysBase = "0.7.0" +SparseArraysBase = "0.8" Test = "<0.0.1, 1" diff --git a/test/Project.toml b/test/Project.toml index fdf61ca..761f1bb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ NamedDimsArrays = "0.12" Random = "<0.0.1, 1" SafeTestsets = "0.1.0" SparseArrays = "1.10" -SparseArraysBase = "0.7" +SparseArraysBase = "0.8" StableRNGs = "1.0.2" Suppressor = "0.2.8" TensorAlgebra = "0.6"