From 9e66bf0b04d8b8cf8a8aaf0cd9c0c58c3a6ff6d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 8 Sep 2025 16:43:04 -0400 Subject: [PATCH 01/13] [WIP] More matrix functions, broadcasting, constructors --- Project.toml | 2 +- .../abstractdiagonalarray.jl | 24 +++ .../diagonalarraydiaginterface.jl | 22 --- src/diaginterface/diaginterface.jl | 1 + src/diagonalarray/diagonalarray.jl | 186 +++++++++++++++++- src/diagonalarray/diagonalmatrix.jl | 88 +++++++++ src/dual.jl | 8 + 7 files changed, 303 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 0909e7e..d37916b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.17" +version = "0.3.18" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/abstractdiagonalarray/abstractdiagonalarray.jl b/src/abstractdiagonalarray/abstractdiagonalarray.jl index 1c23777..6fca242 100644 --- a/src/abstractdiagonalarray/abstractdiagonalarray.jl +++ b/src/abstractdiagonalarray/abstractdiagonalarray.jl @@ -4,6 +4,30 @@ abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2} const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1} +# Define for type stability, for some reason the generic versions +# in SparseArraysBase.jl is not type stable. +# TODO: Investigate type stability of `iszero` in SparseArraysBase.jl. +function Base.iszero(a::AbstractDiagonalArray) + return iszero(diagview(a)) +end + +using FillArrays: AbstractFill, getindex_value +using LinearAlgebra: norm +# TODO: `_norm` works around: +# https://github.com/JuliaArrays/FillArrays.jl/issues/417 +# Change back to `norm` when that is fixed. +_norm(a, p::Int=2) = norm(a, p) +function _norm(a::AbstractFill, p::Int=2) + nrm1 = norm(getindex_value(a)) + return (length(a))^(1/oftype(nrm1, p)) * nrm1 +end +function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2) + # TODO: `_norm` works around: + # https://github.com/JuliaArrays/FillArrays.jl/issues/417 + # Change back to `norm` when that is fixed. + return _norm(diagview(a), p) +end + using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a) function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number}) diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 3f914d3..efcb4f2 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -99,25 +99,3 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle} copyto!(diagview(dest), broadcasted_diagview(bc)) return dest end - -## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i)) - -## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex) -## return a[StorageIndex(i)] -## end - -## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex) -## a[StorageIndex(i)] = value -## return a -## end - -## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i)) - -## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices) -## return a[StorageIndices(i)] -## end - -## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices) -## a[StorageIndices(i)] = value -## return a -## end diff --git a/src/diaginterface/diaginterface.jl b/src/diaginterface/diaginterface.jl index ab572c9..dcbc331 100644 --- a/src/diaginterface/diaginterface.jl +++ b/src/diaginterface/diaginterface.jl @@ -97,6 +97,7 @@ function setdiagindex!(a::AbstractArray, v, i::Integer) end function getdiagindices(a::AbstractArray, I) + # TODO: Should this be a view? return @view diagview(a)[I] end diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 47ce8b2..5e74b90 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -20,10 +20,58 @@ SparseArraysBase.unstored(a::DiagonalArray) = a.unstored Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) +function DiagonalArray(diag::AbstractVector, unstored::Unstored) + return _DiagonalArray(diag, parent(unstored)) +end function DiagonalArray(::UndefInitializer, unstored::Unstored) - return _DiagonalArray( - Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored) - ) + return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored) +end + +# This helps to support diagonals where the elements are known +# from the types, for example diagonals that are `Zeros` and `Ones`. +function DiagonalArray{T,N,D,U}( + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax))) +end +function DiagonalArray{T,N,D,U}( + ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}((ax1, ax_rest...)) +end +function DiagonalArray{T,N,D,U}( + sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(Base.OneTo.(sz)) +end +function DiagonalArray{T,N,D,U}( + sz1::Integer, sz_rest::Vararg{Integer} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}((sz1, sz_rest...)) +end + +# This helps to support diagonals where the elements are known +# from the types, for example diagonals that are `Zeros` and `Ones`. +# These versions use the default unstored type `Zeros{T,N}`. +function DiagonalArray{T,N,D}( + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D,Zeros{T,N}}(ax) +end +function DiagonalArray{T,N,D}( + ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D,Zeros{T,N}}(ax1, ax_rest...) +end +function DiagonalArray{T,N,D}( + sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D,Zeros{T,N}}(sz) +end +function DiagonalArray{T,N,D}( + sz1::Integer, sz_rest::Vararg{Integer} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D,Zeros{T,N}}(sz1, sz_rest...) end # Constructors accepting axes. @@ -32,7 +80,7 @@ function DiagonalArray{T,N}( ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) where {T,N} N == length(ax) || throw(ArgumentError("Wrong number of axes")) - return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax)) + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax))) end function DiagonalArray{T,N}( diag::AbstractVector, @@ -97,7 +145,7 @@ function DiagonalArray{T}( end function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N} - return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims)) + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims))) end function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N} @@ -161,6 +209,28 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} return DiagonalArray{T,N}(undef, dims) end +# 0-dim limit. +function DiagonalArray{T,0,D}( + ::UndefInitializer, ax::Tuple{} +) where {T,D<:AbstractVector{T}} + return DiagonalArray{T,0,D}(D(undef, 0), ax) +end +function DiagonalArray{T,0,D}(::UndefInitializer) where {T,D<:AbstractVector{T}} + return DiagonalArray{T,0,D}(undef, ()) +end +function DiagonalArray{T,0}(::UndefInitializer, ax::Tuple{}) where {T} + return DiagonalArray{T,0,Vector{T}}(undef, ax) +end +function DiagonalArray{T,0}(::UndefInitializer) where {T} + return DiagonalArray{T,0}(undef, ()) +end +function DiagonalArray{T}(::UndefInitializer, axes::Tuple{}) where {T} + return DiagonalArray{T,0}(undef, ()) +end +function DiagonalArray{T}(::UndefInitializer) where {T} + return DiagonalArray{T}(undef, ()) +end + # Axes version function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N} return DiagonalArray{T,N}(undef, length.(axes)) @@ -197,3 +267,109 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) # Unlike `permutedims(::Diagonal, perm)`, we copy here. return DiagonalArray(diagview(a), ax_perm) end + +# Scalar indexing. +using DerivableInterfaces: @interface, interface +one_based_range(r) = false +one_based_range(r::Base.OneTo) = true +one_based_range(r::Base.Slice) = true +function _diag_axes(a::DiagonalArray, I...) + return map(ntuple(identity, ndims(a))) do d + return Base.axes1(axes(a, d)[I[d]]) + end +end +# A view that preserves the diagonal structure. +function _view_diag(a::DiagonalArray, I...) + ax = _diag_axes(a, I...) + return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax) +end +# A slice that preserves the diagonal structure. +function _getindex_diag(a::DiagonalArray, I...) + ax = _diag_axes(a, I...) + return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax) +end +function Base.view(a::DiagonalArray, I...) + I′ = to_indices(a, I) + return if all(one_based_range, I′) + _view_diag(a, I′...) + else + invoke(view, Tuple{AbstractArray,Vararg}, a, I′...) + end +end +function Base.getindex(a::DiagonalArray, I::Int...) + return @interface interface(a) a[I...] +end +function Base.getindex(a::DiagonalArray, I::DiagIndex) + return getdiagindex(a, index(I)) +end +function Base.getindex(a::DiagonalArray, I::DiagIndices) + # TODO: Should this be a view? + return @view diagview(a)[indices(I)] +end +function Base.getindex(a::DiagonalArray, I...) + I′ = to_indices(a, I) + return if all(i -> i isa Real, I′) + # Catch scalar indexing case. + @interface interface(a) a[I...] + elseif all(one_based_range, I′) + _getindex_diag(a, I′...) + else + copy(view(a, I′...)) + end +end + +# Define in order to preserve immutable diagonals such as FillArrays. +function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N} + # TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`: + # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112 + return a +end +function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} + return DiagonalArray{T,N}(diagview(a)) +end +function DiagonalArray{T}(a::DiagonalArray) where {T} + return DiagonalArray{T,ndims(a)}(a) +end +function DiagonalArray(a::DiagonalArray) + return DiagonalArray{eltype(a),ndims(a)}(a) +end +function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} + return DiagonalArray{T,N}(a) +end + +# TODO: These definitions work around this issue: +# https://github.com/JuliaArrays/FillArrays.jl/issues/416 +# when the diagonal is a FillArrays.Ones or Zeros. +using Base.Broadcast: Broadcast, broadcast, broadcasted +using FillArrays: AbstractFill, Ones, Zeros +_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a) +_broadcasted(::typeof(identity), a::Ones) = a +_broadcasted(::typeof(identity), a::Zeros) = a +_broadcasted(::typeof(complex), a::Ones) = Ones{complex(eltype(a))}(axes(a)) +_broadcasted(::typeof(complex), a::Zeros) = Zeros{complex(eltype(a))}(axes(a)) +_broadcasted(elt::Type, a::Ones) = Ones{elt}(axes(a)) +_broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a)) +_broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) +using LinearAlgebra: pinv +_broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) +_broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a) +_broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a) +_broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a) +_broadcasted(::typeof(cbrt), a::Zeros) = _broadcasted(typeof(cbrt(zero(eltype(a)))), a) +_broadcasted(::typeof(exp), a::Zeros) = Ones{typeof(exp(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cis), a::Zeros) = Ones{typeof(cis(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(log), a::Ones) = Zeros{typeof(log(one(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cos), a::Zeros) = Ones{typeof(cos(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(sin), a::Zeros) = _broadcasted(typeof(sin(zero(eltype(a)))), a) +_broadcasted(::typeof(tan), a::Zeros) = _broadcasted(typeof(tan(zero(eltype(a)))), a) +_broadcasted(::typeof(sec), a::Zeros) = Ones{typeof(sec(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axes(a)) +# Eager version of `_broadcasted`. +_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a)) + +function Broadcast.broadcasted( + ::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag} +) where {F,T,N,Diag<:AbstractFill{T}} + # TODO: Check that `f` preserves zeros? + return DiagonalArray(_broadcasted(f, diagview(a)), axes(a)) +end diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 4eb6fab..36800f7 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -58,3 +58,91 @@ function LinearAlgebra.mul!( d_dest .= d1 .* d2 .* α .+ d_dest .* β return a_dest end + +# Adapted from https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L866-L928. +function LinearAlgebra.tr(a::DiagonalMatrix) + checksquare(a) + # TODO: Define as `sum(tr, diagview(a))` like LinearAlgebra.jl? + return sum(diagview(a)) +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.det(a::DiagonalMatrix) + checksquare(a) + # TODO: Define as `prod(det, diagview(a))` like LinearAlgebra.jl? + return prod(diagview(a)) +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.logabsdet(a::DiagonalMatrix) + checksquare(a) + return mapreduce(((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), diagview(a)) do x + return (log(abs(x)), sign(x)) + end +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.logdet(a::DiagonalMatrix{<:Complex}) + checksquare(a) + z = sum(log, diagview(a)) + return complex(real(z), rem2pi(imag(z), RoundNearest)) +end + +# Matrix functions +for f in ( + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +) + @eval begin + function Base.$f(a::DiagonalMatrix) + checksquare(a) + return DiagonalMatrix(_broadcast($f, diagview(a)), axes(a)) + end + end +end + +# Cube root of a real-valued diagonal matrix +function Base.cbrt(a::DiagonalMatrix{<:Real}) + checksquare(a) + return DiagonalMatrix(_broadcast(cbrt, diagview(a)), axes(a)) +end + +function LinearAlgebra.inv(a::DiagonalMatrix) + checksquare(a) + # `DiagonalArrays._broadcast` works around issues like https://github.com/JuliaArrays/FillArrays.jl/issues/416 + # when the diagonal is a FillArray or similar lazy array. + d⁻¹ = _broadcast(inv, diagview(a)) + any(isinf, d⁻¹) && error("Singular Exception") + return DiagonalMatrix(d⁻¹, axes(a)) +end + +# TODO: Support `atol` and `rtol` keyword arguments: +# https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.pinv +using LinearAlgebra: pinv +function LinearAlgebra.pinv(a::DiagonalMatrix) + checksquare(a) + return DiagonalMatrix(_broadcast(pinv, diagview(a)), axes(a)) +end diff --git a/src/dual.jl b/src/dual.jl index 9780a3b..36b6ee5 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -1,3 +1,11 @@ # TODO: Define `TensorProducts.dual`. dual(x) = x issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2))) +# Like `LinearAlgebra.checksquare` but based on `DiagonalArrays.issquare`, +# which checks the axes and allows customizing to check that the +# codomain is the dual of the domain. +# Returns the codomain if the check passes. +function checksquare(a::AbstractMatrix) + issquare(a) || throw(DimensionMismatch(lazy"matrix is not square: axes are $(axes(a))")) + return axes(a, 1) +end From cef70fa0276eb45f1592c8e59fc4613eb5a4ccbc Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 8 Sep 2025 21:08:55 -0400 Subject: [PATCH 02/13] More tests --- src/diagonalarray/diagonalarray.jl | 20 +++++++++++++++----- test/test_basics.jl | 13 ++++++++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 5e74b90..0cb6c48 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -27,12 +27,22 @@ function DiagonalArray(::UndefInitializer, unstored::Unstored) return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored) end +function construct_from_length(vect::Type{<:AbstractVector}, len::Integer) + if applicable(vect, len) + return vect(len) + elseif applicable(vect, (Base.OneTo(len),)) + return vect((Base.OneTo(len),)) + else + error(lazy"Can't construct $(vect) from length.") + end +end + # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. function DiagonalArray{T,N,D,U}( ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax))) + return DiagonalArray(construct_from_length(D, minimum(length, ax)), Unstored(U(ax))) end function DiagonalArray{T,N,D,U}( ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}} @@ -40,12 +50,12 @@ function DiagonalArray{T,N,D,U}( return DiagonalArray{T,N,D,U}((ax1, ax_rest...)) end function DiagonalArray{T,N,D,U}( - sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}} + sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} return DiagonalArray{T,N,D,U}(Base.OneTo.(sz)) end function DiagonalArray{T,N,D,U}( - sz1::Integer, sz_rest::Vararg{Integer} + sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} return DiagonalArray{T,N,D,U}((sz1, sz_rest...)) end @@ -64,12 +74,12 @@ function DiagonalArray{T,N,D}( return DiagonalArray{T,N,D,Zeros{T,N}}(ax1, ax_rest...) end function DiagonalArray{T,N,D}( - sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}} + sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D,Zeros{T,N}}(sz) end function DiagonalArray{T,N,D}( - sz1::Integer, sz_rest::Vararg{Integer} + sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D,Zeros{T,N}}(sz1, sz_rest...) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4bc24cb..98ab8de 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -15,7 +15,7 @@ using DiagonalArrays: diagonal, diagonaltype, diagview -using FillArrays: Fill, Ones +using FillArrays: Fill, Ones, Zeros using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric @@ -105,6 +105,17 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric eltype(DiagonalArray{elt,2}(undef, (2, 2))) ≡ eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) + + # Special constructors for immutable diagonal. + @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2))...) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}((2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(2, 2) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2))...) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}((2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(2, 2) end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) From 37891d2ec8e78a39e52e78b87061b73f03730fe7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 09:54:18 -0400 Subject: [PATCH 03/13] More tests --- src/diagonalarray/diagonalarray.jl | 66 +++++++++++++++++++++--------- test/test_basics.jl | 22 ++++++---- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 0cb6c48..43ab8a3 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -27,7 +27,14 @@ function DiagonalArray(::UndefInitializer, unstored::Unstored) return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored) end -function construct_from_length(vect::Type{<:AbstractVector}, len::Integer) +# Indicate we will construct an array just from the shape, +# for example for a Base.OneTo or FillArrays.Ones or Zeros. +# All the elements should be uniquely defined by the input axes. +struct ShapeInitializer end + +# This is used to create custom constructors for arrays, +# in this case a generic constructor of a vector from a length. +function construct(vect::Type{<:AbstractVector}, ::ShapeInitializer, len::Integer) if applicable(vect, len) return vect(len) elseif applicable(vect, (Base.OneTo(len),)) @@ -40,51 +47,57 @@ end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. function DiagonalArray{T,N,D,U}( - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} + init::ShapeInitializer, + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray(construct_from_length(D, minimum(length, ax)), Unstored(U(ax))) + return DiagonalArray(construct(D, init, minimum(length, ax)), Unstored(U(ax))) end function DiagonalArray{T,N,D,U}( - ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}} + init::ShapeInitializer, + ax1::AbstractUnitRange{<:Integer}, + ax_rest::Vararg{AbstractUnitRange{<:Integer}}, ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}((ax1, ax_rest...)) + return DiagonalArray{T,N,D,U}(init, (ax1, ax_rest...)) end function DiagonalArray{T,N,D,U}( - sz::Tuple{Integer,Vararg{Integer}} + init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(Base.OneTo.(sz)) + return DiagonalArray{T,N,D,U}(init, Base.OneTo.(sz)) end function DiagonalArray{T,N,D,U}( - sz1::Integer, sz_rest::Integer... + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}((sz1, sz_rest...)) + return DiagonalArray{T,N,D,U}(init, (sz1, sz_rest...)) end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. # These versions use the default unstored type `Zeros{T,N}`. function DiagonalArray{T,N,D}( - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} + init::ShapeInitializer, + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(ax) + return DiagonalArray{T,N,D,Zeros{T,N}}(init, ax) end function DiagonalArray{T,N,D}( - ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}} + init::ShapeInitializer, + ax1::AbstractUnitRange{<:Integer}, + ax_rest::Vararg{AbstractUnitRange{<:Integer}}, ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(ax1, ax_rest...) + return DiagonalArray{T,N,D,Zeros{T,N}}(init, ax1, ax_rest...) end function DiagonalArray{T,N,D}( - sz::Tuple{Integer,Vararg{Integer}} + init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(sz) + return DiagonalArray{T,N,D,Zeros{T,N}}(init, sz) end function DiagonalArray{T,N,D}( - sz1::Integer, sz_rest::Integer... + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(sz1, sz_rest...) + return DiagonalArray{T,N,D,Zeros{T,N}}(init, sz1, sz_rest...) end -# Constructors accepting axes. +# Constructor from diagonal entries accepting axes. function DiagonalArray{T,N}( diag::AbstractVector, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, @@ -219,7 +232,22 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} return DiagonalArray{T,N}(undef, dims) end -# 0-dim limit. +# 0-dim from diag. +function DiagonalArray{T,0,D}( + diag::AbstractVector, ax::Tuple{} +) where {T,D<:AbstractVector{T}} + error() +end +function DiagonalArray{T,0}(diag::AbstractVector, ax::Tuple{}) where {T} + diag′ = convert(AbstractVector{T}, diag) + D = typeof(diag′) + return DiagonalArray{T,0,D}(diag, ax) +end +function DiagonalArray{T,0}(diag::AbstractVector) where {T} + return DiagonalArray{T,0}(diag, ()) +end + +# 0-dim undef. function DiagonalArray{T,0,D}( ::UndefInitializer, ax::Tuple{} ) where {T,D<:AbstractVector{T}} diff --git a/test/test_basics.jl b/test/test_basics.jl index 98ab8de..99580a9 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,6 +2,7 @@ using Test: @test, @testset, @test_broken, @inferred using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, + ShapeInitializer, Delta, DeltaMatrix, DiagonalArray, @@ -107,15 +108,20 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) # Special constructors for immutable diagonal. + init = ShapeInitializer() @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2))...) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}((2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(2, 2) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2))...) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}((2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(2, 2) + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( + init, Base.OneTo.((2, 2)) + ) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( + init, Base.OneTo.((2, 2))... + ) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, (2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, 2, 2) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) From 85dc72a71612b82906cd10158f1cf7a917a2cb3d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 11:38:37 -0400 Subject: [PATCH 04/13] More tests --- .../diagonalarraydiaginterface.jl | 7 ++ src/diagonalarray/diagonalarray.jl | 118 +++++++++--------- test/test_basics.jl | 34 +++++ 3 files changed, 103 insertions(+), 56 deletions(-) diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index efcb4f2..d9f203d 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -43,6 +43,9 @@ function SparseArraysBase.getstoredindex( # allequal(I) || error("Not a diagonal index.") return getdiagindex(a, first(I)) end +function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any,0}) + return getdiagindex(a, 1) +end function SparseArraysBase.setstoredindex!( a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N} ) where {N} @@ -52,6 +55,10 @@ function SparseArraysBase.setstoredindex!( setdiagindex!(a, value, first(I)) return a end +function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any,0}, value) + setdiagindex!(a, value, 1) + return a +end function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray) return diagindices(a) end diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 43ab8a3..c32f546 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -1,6 +1,14 @@ using FillArrays: Zeros using SparseArraysBase: Unstored, unstored +diaglength_from_shape(sz::Tuple{Integer,Vararg{Integer}}) = minimum(sz) +function diaglength_from_shape( + sz::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} +) + return minimum(length, sz) +end +diaglength_from_shape(sz::Tuple{}) = 1 + function _DiagonalArray end struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <: @@ -10,7 +18,7 @@ struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} < global @inline function _DiagonalArray( diag::Diag, unstored::Unstored ) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} - length(diag) == minimum(size(unstored)) || + length(diag) == diaglength_from_shape(size(unstored)) || throw(ArgumentError("Length of diagonals doesn't match dimensions")) return new{T,N,Diag,Unstored}(diag, unstored) end @@ -24,7 +32,9 @@ function DiagonalArray(diag::AbstractVector, unstored::Unstored) return _DiagonalArray(diag, parent(unstored)) end function DiagonalArray(::UndefInitializer, unstored::Unstored) - return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored) + return DiagonalArray( + Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored + ) end # Indicate we will construct an array just from the shape, @@ -47,17 +57,14 @@ end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray(construct(D, init, minimum(length, ax)), Unstored(U(ax))) + return DiagonalArray(construct(D, init, diaglength_from_shape(ax)), Unstored(U(ax))) end function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, - ax1::AbstractUnitRange{<:Integer}, - ax_rest::Vararg{AbstractUnitRange{<:Integer}}, + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, (ax1, ax_rest...)) + return DiagonalArray{T,N,D,U}(init, ax) end function DiagonalArray{T,N,D,U}( init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} @@ -74,27 +81,24 @@ end # from the types, for example diagonals that are `Zeros` and `Ones`. # These versions use the default unstored type `Zeros{T,N}`. function DiagonalArray{T,N,D}( - init::ShapeInitializer, - ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(init, ax) + return DiagonalArray{T,N,D,Zeros{T,N,typeof(ax)}}(init, ax) end function DiagonalArray{T,N,D}( - init::ShapeInitializer, - ax1::AbstractUnitRange{<:Integer}, - ax_rest::Vararg{AbstractUnitRange{<:Integer}}, + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(init, ax1, ax_rest...) + return DiagonalArray{T,N,D}(init, ax) end function DiagonalArray{T,N,D}( init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(init, sz) + return DiagonalArray{T,N,D}(init, Base.OneTo.(sz)) end function DiagonalArray{T,N,D}( init::ShapeInitializer, sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N}}(init, sz1, sz_rest...) + return DiagonalArray{T,N,D}(init, (sz1, sz_rest...)) end # Constructor from diagonal entries accepting axes. @@ -217,7 +221,7 @@ end # undef function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims) + return DiagonalArray{T,N}(Vector{T}(undef, diaglength_from_shape(dims)), dims) end function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} @@ -233,45 +237,47 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} end # 0-dim from diag. -function DiagonalArray{T,0,D}( - diag::AbstractVector, ax::Tuple{} -) where {T,D<:AbstractVector{T}} - error() -end -function DiagonalArray{T,0}(diag::AbstractVector, ax::Tuple{}) where {T} - diag′ = convert(AbstractVector{T}, diag) - D = typeof(diag′) - return DiagonalArray{T,0,D}(diag, ax) -end -function DiagonalArray{T,0}(diag::AbstractVector) where {T} - return DiagonalArray{T,0}(diag, ()) -end - -# 0-dim undef. -function DiagonalArray{T,0,D}( - ::UndefInitializer, ax::Tuple{} -) where {T,D<:AbstractVector{T}} - return DiagonalArray{T,0,D}(D(undef, 0), ax) -end -function DiagonalArray{T,0,D}(::UndefInitializer) where {T,D<:AbstractVector{T}} - return DiagonalArray{T,0,D}(undef, ()) -end -function DiagonalArray{T,0}(::UndefInitializer, ax::Tuple{}) where {T} - return DiagonalArray{T,0,Vector{T}}(undef, ax) -end -function DiagonalArray{T,0}(::UndefInitializer) where {T} - return DiagonalArray{T,0}(undef, ()) -end -function DiagonalArray{T}(::UndefInitializer, axes::Tuple{}) where {T} - return DiagonalArray{T,0}(undef, ()) -end -function DiagonalArray{T}(::UndefInitializer) where {T} - return DiagonalArray{T}(undef, ()) -end +## function DiagonalArray{T,0,D}( +## diag::AbstractVector, ax::Tuple{} +## ) where {T,D<:AbstractVector{T}} +## error() +## end +## function DiagonalArray{T,0}(diag::AbstractVector, ax::Tuple{}) where {T} +## diag′ = convert(AbstractVector{T}, diag) +## D = typeof(diag′) +## return DiagonalArray{T,0,D}(diag, ax) +## end +## function DiagonalArray{T,0}(diag::AbstractVector) where {T} +## return DiagonalArray{T,0}(diag, ()) +## end + +## # 0-dim undef. +## function DiagonalArray{T,0,D}( +## ::UndefInitializer, ax::Tuple{} +## ) where {T,D<:AbstractVector{T}} +## return DiagonalArray{T,0,D}(D(undef, 0), ax) +## end +## function DiagonalArray{T,0,D}(::UndefInitializer) where {T,D<:AbstractVector{T}} +## return DiagonalArray{T,0,D}(undef, ()) +## end +## function DiagonalArray{T,0}(::UndefInitializer, ax::Tuple{}) where {T} +## return DiagonalArray{T,0,Vector{T}}(undef, ax) +## end +## function DiagonalArray{T,0}(::UndefInitializer) where {T} +## return DiagonalArray{T,0}(undef, ()) +## end +## function DiagonalArray{T}(::UndefInitializer, axes::Tuple{}) where {T} +## return DiagonalArray{T,0}(undef, ()) +## end +## function DiagonalArray{T}(::UndefInitializer) where {T} +## return DiagonalArray{T}(undef, ()) +## end # Axes version -function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N} - return DiagonalArray{T,N}(undef, length.(axes)) +function DiagonalArray{T}( + ::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} +) where {T} + return DiagonalArray{T,length(axes)}(undef, length.(axes)) end function Base.similar(a::DiagonalArray, unstored::Unstored) diff --git a/test/test_basics.jl b/test/test_basics.jl index 99580a9..8ad4a0e 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -122,6 +122,40 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) + + # 0-dim constructors + v = randn(elt, 1) + @test DiagonalArray(v) ≡ + DiagonalArray(v, ()) ≡ + DiagonalArray{elt}(v) ≡ + DiagonalArray{elt}(v, ()) ≡ + DiagonalArray{elt,0}(v) ≡ + DiagonalArray{elt,0}(v, ()) + @test size(DiagonalArray{elt}(undef)) ≡ + size(DiagonalArray{elt}(undef, ())) ≡ + size(DiagonalArray{elt,0}(undef)) ≡ + size(DiagonalArray{elt,0}(undef, ())) + @test elt ≡ + eltype(DiagonalArray{elt}(undef)) ≡ + eltype(DiagonalArray{elt}(undef, ())) ≡ + eltype(DiagonalArray{elt,0}(undef)) ≡ + eltype(DiagonalArray{elt,0}(undef, ())) + + # Special constructors for immutable diagonal. + init = ShapeInitializer() + @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( + init, Base.OneTo.((2, 2)) + ) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( + init, Base.OneTo.((2, 2))... + ) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, (2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, 2, 2) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) From 8dcddf63cd44fe48c57f36b5fea64af11102a6da Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 11:40:40 -0400 Subject: [PATCH 05/13] More 0-dim tests --- test/test_basics.jl | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 8ad4a0e..e780714 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -143,19 +143,11 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric # Special constructors for immutable diagonal. init = ShapeInitializer() - @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( - init, Base.OneTo.((2, 2)) - ) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( - init, Base.OneTo.((2, 2))... - ) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, (2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, 2, 2) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) + @test DiagonalArray{<:Any,0}(Base.OneTo(UInt32(1))) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32},Zeros{UInt32,0}}(init, ()) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32},Zeros{UInt32,0}}(init) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, ()) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) From 5188f4b7bef52edba3931256d1a1381cda8cb67c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 12:03:51 -0400 Subject: [PATCH 06/13] Slicing tests --- src/diagonalarray/diagonalarray.jl | 45 ++++++------------------------ test/test_basics.jl | 36 +++++++++++++++++++++++- 2 files changed, 43 insertions(+), 38 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index c32f546..5f4aab4 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -236,43 +236,6 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} return DiagonalArray{T,N}(undef, dims) end -# 0-dim from diag. -## function DiagonalArray{T,0,D}( -## diag::AbstractVector, ax::Tuple{} -## ) where {T,D<:AbstractVector{T}} -## error() -## end -## function DiagonalArray{T,0}(diag::AbstractVector, ax::Tuple{}) where {T} -## diag′ = convert(AbstractVector{T}, diag) -## D = typeof(diag′) -## return DiagonalArray{T,0,D}(diag, ax) -## end -## function DiagonalArray{T,0}(diag::AbstractVector) where {T} -## return DiagonalArray{T,0}(diag, ()) -## end - -## # 0-dim undef. -## function DiagonalArray{T,0,D}( -## ::UndefInitializer, ax::Tuple{} -## ) where {T,D<:AbstractVector{T}} -## return DiagonalArray{T,0,D}(D(undef, 0), ax) -## end -## function DiagonalArray{T,0,D}(::UndefInitializer) where {T,D<:AbstractVector{T}} -## return DiagonalArray{T,0,D}(undef, ()) -## end -## function DiagonalArray{T,0}(::UndefInitializer, ax::Tuple{}) where {T} -## return DiagonalArray{T,0,Vector{T}}(undef, ax) -## end -## function DiagonalArray{T,0}(::UndefInitializer) where {T} -## return DiagonalArray{T,0}(undef, ()) -## end -## function DiagonalArray{T}(::UndefInitializer, axes::Tuple{}) where {T} -## return DiagonalArray{T,0}(undef, ()) -## end -## function DiagonalArray{T}(::UndefInitializer) where {T} -## return DiagonalArray{T}(undef, ()) -## end - # Axes version function DiagonalArray{T}( ::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} @@ -327,11 +290,19 @@ function _view_diag(a::DiagonalArray, I...) ax = _diag_axes(a, I...) return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax) end +function _view_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(view(diagview(a), :), ax) +end # A slice that preserves the diagonal structure. function _getindex_diag(a::DiagonalArray, I...) ax = _diag_axes(a, I...) return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax) end +function _getindex_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(diagview(a)[:], ax) +end function Base.view(a::DiagonalArray, I...) I′ = to_indices(a, I) return if all(one_based_range, I′) diff --git a/test/test_basics.jl b/test/test_basics.jl index e780714..623515b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -17,7 +17,7 @@ using DiagonalArrays: diagonaltype, diagview using FillArrays: Fill, Ones, Zeros -using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength +using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric @testset "Test DiagonalArrays" begin @@ -149,6 +149,40 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, ()) ≡ DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) end + @testset "Slicing" begin + # Slicing that preserves the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[:, :] + @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} + @test diagview(b) ≡ view(diagview(a), :) + + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[Base.OneTo(2), Base.OneTo(2)] + @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} + @test diagview(b) ≡ view(diagview(a), Base.OneTo(2)) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[:, :] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[Base.OneTo(2), Base.OneTo(2)] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a)[Base.OneTo(2)] + + # Slicing that doesn't preserve the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[2:3, 2:3] + @test b isa SubArray + @test b == Matrix(a)[2:3, 2:3] + + a = DiagonalMatrix(randn(elt, 3)) + b = a[2:3, 2:3] + @test b isa SparseMatrixDOK{elt} + @test b == Matrix(a)[2:3, 2:3] + @test storedlength(b) == 2 + end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) b = permutedims(a, (3, 1, 2)) From 2b83052f5400d8c3976f00f3263ed60d29068195 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 12:45:27 -0400 Subject: [PATCH 07/13] Refactor constructors a little bit --- src/diagonalarray/diagonalarray.jl | 67 +++++++++++++++++------------- test/test_basics.jl | 17 +++----- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 5f4aab4..ef82ccc 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -9,18 +9,16 @@ function diaglength_from_shape( end diaglength_from_shape(sz::Tuple{}) = 1 -function _DiagonalArray end - -struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <: +struct DiagonalArray{T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} <: AbstractDiagonalArray{T,N} - diag::Diag - unstored::Unstored - global @inline function _DiagonalArray( - diag::Diag, unstored::Unstored - ) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} + diag::D + unstored::U + function DiagonalArray{T,N,D,U}( + diag::AbstractVector, unstored::Unstored + ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} length(diag) == diaglength_from_shape(size(unstored)) || throw(ArgumentError("Length of diagonals doesn't match dimensions")) - return new{T,N,Diag,Unstored}(diag, unstored) + return new{T,N,D,U}(diag, parent(unstored)) end end @@ -28,9 +26,31 @@ SparseArraysBase.unstored(a::DiagonalArray) = a.unstored Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) +function DiagonalArray{T,N,D}( + diag::D, unstored::Unstored{T,N} +) where {T,N,D<:AbstractVector{T}} + U = typeof(parent(unstored)) + return DiagonalArray{T,N,D,U}(diag, unstored) +end +function DiagonalArray{T,N}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} + D = typeof(diag) + U = typeof(parent(unstored)) + return DiagonalArray{T,N,D,U}(diag, unstored) +end +function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} + N = ndims(unstored) + D = typeof(diag) + U = typeof(parent(unstored)) + return DiagonalArray{T,N,D,U}(diag, unstored) +end function DiagonalArray(diag::AbstractVector, unstored::Unstored) - return _DiagonalArray(diag, parent(unstored)) + T = eltype(diag) + N = ndims(unstored) + D = typeof(diag) + U = typeof(parent(unstored)) + return DiagonalArray{T,N,D,U}(diag, unstored) end + function DiagonalArray(::UndefInitializer, unstored::Unstored) return DiagonalArray( Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored @@ -56,25 +76,12 @@ end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray(construct(D, init, diaglength_from_shape(ax)), Unstored(U(ax))) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, ax) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, Base.OneTo.(sz)) -end -function DiagonalArray{T,N,D,U}( - init::ShapeInitializer, sz1::Integer, sz_rest::Integer... -) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} - return DiagonalArray{T,N,D,U}(init, (sz1, sz_rest...)) +function DiagonalArray{T,N,D}( + init::ShapeInitializer, unstored::Unstored +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}( + construct(D, init, diaglength_from_shape(axes(unstored))), unstored + ) end # This helps to support diagonals where the elements are known @@ -83,7 +90,7 @@ end function DiagonalArray{T,N,D}( init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D,Zeros{T,N,typeof(ax)}}(init, ax) + return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax))) end function DiagonalArray{T,N,D}( init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... diff --git a/test/test_basics.jl b/test/test_basics.jl index 623515b..7495001 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -9,6 +9,7 @@ using DiagonalArrays: DiagonalMatrix, ScaledDelta, ScaledDeltaMatrix, + Unstored, δ, delta, diagindices, @@ -110,18 +111,11 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric # Special constructors for immutable diagonal. init = ShapeInitializer() @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( - init, Base.OneTo.((2, 2)) - ) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}( - init, Base.OneTo.((2, 2))... - ) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, (2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(init, 2, 2) ≡ DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) ≡ + DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) # 0-dim constructors v = randn(elt, 1) @@ -144,10 +138,9 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric # Special constructors for immutable diagonal. init = ShapeInitializer() @test DiagonalArray{<:Any,0}(Base.OneTo(UInt32(1))) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32},Zeros{UInt32,0}}(init, ()) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32},Zeros{UInt32,0}}(init) ≡ DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, ()) ≡ - DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) end @testset "Slicing" begin # Slicing that preserves the diagonal structure. From ebb9b557e7b50406e7b0e4d7a778bb7a9f44d64d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 12:54:48 -0400 Subject: [PATCH 08/13] More tests --- test/test_basics.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_basics.jl b/test/test_basics.jl index 7495001..0e654e0 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -142,6 +142,17 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) ≡ DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) end + @testset "Conversion" begin + a = DiagonalMatrix(randn(elt, 2)) + @test DiagonalMatrix{elt}(a) ≡ a + @test DiagonalMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray(a) ≡ a + @test AbstractMatrix{elt}(a) ≡ a + @test AbstractMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test AbstractArray{elt}(a) ≡ a + @test AbstractArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + end @testset "Slicing" begin # Slicing that preserves the diagonal structure. a = DiagonalMatrix(randn(elt, 3)) From ea6d5ecc56668414c5e5d08223f0069631ca997f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 14:17:31 -0400 Subject: [PATCH 09/13] More tests --- src/diagonalarray/diagonalarray.jl | 1 + test/test_basics.jl | 37 +++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index ef82ccc..aaa4175 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -374,6 +374,7 @@ _broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a)) _broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) using LinearAlgebra: pinv _broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) +_broadcasted(::typeof(pinv), a::Zeros) = _broadcasted(typeof(inv(zero(eltype(a)))), a) _broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a) _broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a) _broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a) diff --git a/test/test_basics.jl b/test/test_basics.jl index 0e654e0..a3f0781 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -19,7 +19,7 @@ using DiagonalArrays: diagview using FillArrays: Fill, Ones, Zeros using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength -using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric +using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric, pinv @testset "Test DiagonalArrays" begin @testset "DiagonalArray (eltype=$elt)" for elt in ( @@ -215,6 +215,41 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric # Non-zero-preserving functions not supported yet. c = DiagonalArray{elt}(undef, (2, 3)) @test_broken c .= a .+ 2 + + a_ones = DiagonalMatrix(Ones{elt}(2)) + a_zeros = DiagonalMatrix(Zeros{elt}(2)) + @test identity.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test identity.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test complex.(a_ones) ≡ DiagonalMatrix(Ones{complex(elt)}(2)) + @test complex.(a_zeros) ≡ DiagonalMatrix(Zeros{complex(elt)}(2)) + @test Float32.(a_ones) ≡ DiagonalMatrix(Ones{Float32}(2)) + @test Float32.(a_zeros) ≡ DiagonalMatrix(Zeros{Float32}(2)) + @test inv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test inv.(a_zeros) ≡ DiagonalMatrix(Fill(inv(zero(elt)), 2)) + @test pinv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test pinv.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test sqrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test sqrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + if elt <: Real + @test cbrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test cbrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + end + @test exp.(a_ones) ≡ DiagonalMatrix(Fill(exp(one(elt)), 2)) + @test exp.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(exp(zero(elt)))}(2)) + @test cis.(a_ones) ≡ DiagonalMatrix(Fill(cis(one(elt)), 2)) + @test cis.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cis(zero(elt)))}(2)) + @test log.(a_ones) ≡ DiagonalMatrix(Zeros{typeof(log(one(elt)))}(2)) + @test log.(a_zeros) ≡ DiagonalMatrix(Fill(log(zero(elt)), 2)) + @test cos.(a_ones) ≡ DiagonalMatrix(Fill(cos(one(elt)), 2)) + @test cos.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cos(zero(elt)))}(2)) + @test sin.(a_ones) ≡ DiagonalMatrix(Fill(sin(one(elt)), 2)) + @test sin.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(sin(zero(elt)))}(2)) + @test tan.(a_ones) ≡ DiagonalMatrix(Fill(tan(one(elt)), 2)) + @test tan.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(tan(zero(elt)))}(2)) + @test sec.(a_ones) ≡ DiagonalMatrix(Fill(sec(one(elt)), 2)) + @test sec.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(sec(zero(elt)))}(2)) + @test cosh.(a_ones) ≡ DiagonalMatrix(Fill(cosh(one(elt)), 2)) + @test cosh.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cosh(zero(elt)))}(2)) end @testset "LinearAlgebra matrix properties" begin @test ishermitian(DiagonalMatrix([1, 2])) From 588367874f7b24e9be800741f9429ca50cead1cb Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 14:55:23 -0400 Subject: [PATCH 10/13] More tests --- src/diagonalarray/diagonalarray.jl | 14 ++---- src/diagonalarray/diagonalmatrix.jl | 4 +- test/test_basics.jl | 72 ++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index aaa4175..7deb7c0 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -34,21 +34,15 @@ function DiagonalArray{T,N,D}( end function DiagonalArray{T,N}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} D = typeof(diag) - U = typeof(parent(unstored)) - return DiagonalArray{T,N,D,U}(diag, unstored) + return DiagonalArray{T,N,D}(diag, unstored) end -function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} - N = ndims(unstored) +function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} D = typeof(diag) U = typeof(parent(unstored)) return DiagonalArray{T,N,D,U}(diag, unstored) end -function DiagonalArray(diag::AbstractVector, unstored::Unstored) - T = eltype(diag) - N = ndims(unstored) - D = typeof(diag) - U = typeof(parent(unstored)) - return DiagonalArray{T,N,D,U}(diag, unstored) +function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} + return DiagonalArray{T}(diag, unstored) end function DiagonalArray(::UndefInitializer, unstored::Unstored) diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 36800f7..d3a310d 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -86,7 +86,7 @@ function LinearAlgebra.logdet(a::DiagonalMatrix{<:Complex}) end # Matrix functions -for f in ( +for f in [ :exp, :cis, :log, @@ -115,7 +115,7 @@ for f in ( :acsch, :asech, :acoth, -) +] @eval begin function Base.$f(a::DiagonalMatrix) checksquare(a) diff --git a/test/test_basics.jl b/test/test_basics.jl index a3f0781..d3c6adb 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -19,7 +19,8 @@ using DiagonalArrays: diagview using FillArrays: Fill, Ones, Zeros using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength -using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric, pinv +using LinearAlgebra: + Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr @testset "Test DiagonalArrays" begin @testset "DiagonalArray (eltype=$elt)" for elt in ( @@ -142,6 +143,13 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric, pinv DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) ≡ DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) end + @testset "0-dim operations" begin + diag = randn(elt, 1) + a = DiagonalArray(diag) + @test a[] == diag[1] + a[] = 2 + @test a[] == 2 + end @testset "Conversion" begin a = DiagonalMatrix(randn(elt, 2)) @test DiagonalMatrix{elt}(a) ≡ a @@ -273,6 +281,68 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric, pinv @test !isposdef(DiagonalMatrix([randn(2, 2), randn(3, 3)])) @test !isposdef(DiagonalMatrix([randn(2, 2), randn(2, 3)])) end + @testset "LinearAlgebra matrix functions" begin + diag = randn(elt, 2) + a = DiagonalMatrix(diag) + @test tr(a) ≈ sum(diag) + @test det(a) ≈ prod(diag) + + # Use a positive diagonal in order to take the `log`. + diag = rand(elt, 2) + a = DiagonalMatrix(diag) + @test real(logdet(a)) ≈ real(sum(log, diag)) + @test imag(logdet(a)) ≈ rem2pi(imag(sum(log, diag)), RoundNearest) + + for f in [ + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acot, + :asinh, + :atanh, + :acsch, + :asech, + ] + @eval begin + a = DiagonalMatrix(rand($elt, 2)) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + for f in [:acsc, :asec, :acosh, :acoth] + @eval begin + a = DiagonalMatrix(inv.(rand($elt, 2))) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + if elt <: Real + a = DiagonalMatrix(randn(elt, 2)) + @test cbrt(a) ≈ DiagonalMatrix(cbrt.(diagview(a))) + end + + a = DiagonalMatrix(randn(elt, 2)) + @test inv(a) ≈ DiagonalMatrix(inv.(diagview(a))) + + a = DiagonalMatrix(randn(elt, 2)) + @test pinv(a) ≈ DiagonalMatrix(pinv.(diagview(a))) + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11 From ed27769482e187d62465f24b7843b1b3645e7189 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 15:20:44 -0400 Subject: [PATCH 11/13] More tests --- src/diagonalarray/diagonalarray.jl | 14 ++++------ test/test_basics.jl | 44 +++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 7deb7c0..7e0c7cd 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -27,19 +27,17 @@ Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) function DiagonalArray{T,N,D}( - diag::D, unstored::Unstored{T,N} -) where {T,N,D<:AbstractVector{T}} - U = typeof(parent(unstored)) + diag::D, unstored::Unstored{T,N,U} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} return DiagonalArray{T,N,D,U}(diag, unstored) end -function DiagonalArray{T,N}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} - D = typeof(diag) +function DiagonalArray{T,N}( + diag::D, unstored::Unstored{T,N} +) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D}(diag, unstored) end function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} - D = typeof(diag) - U = typeof(parent(unstored)) - return DiagonalArray{T,N,D,U}(diag, unstored) + return DiagonalArray{T,N}(diag, unstored) end function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} return DiagonalArray{T}(diag, unstored) diff --git a/test/test_basics.jl b/test/test_basics.jl index d3c6adb..f28a061 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -16,7 +16,8 @@ using DiagonalArrays: diaglength, diagonal, diagonaltype, - diagview + diagview, + getdiagindices using FillArrays: Fill, Ones, Zeros using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength using LinearAlgebra: @@ -104,19 +105,29 @@ using LinearAlgebra: eltype(DiagonalArray{elt}(undef, (2, 2))) ≡ eltype(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ eltype(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ - eltype(DiagonalArray{elt,2}(undef, 2, 2)) ≡ - eltype(DiagonalArray{elt,2}(undef, (2, 2))) ≡ - eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) + eltype(DiagonalMatrix{elt}(undef, 2, 2)) ≡ + eltype(DiagonalMatrix{elt}(undef, (2, 2))) ≡ + eltype(DiagonalMatrix{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + eltype(DiagonalMatrix{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) # Special constructors for immutable diagonal. init = ShapeInitializer() @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, (2, 2)) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, 2, 2) ≡ - DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2)) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) + @test DiagonalMatrix(Ones{elt}(2)) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( + init, Base.OneTo.((2, 2))... + ) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, (2, 2)) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, 2, 2) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( + init, Unstored(Zeros{elt}(2, 2)) + ) # 0-dim constructors v = randn(elt, 1) @@ -259,6 +270,16 @@ using LinearAlgebra: @test cosh.(a_ones) ≡ DiagonalMatrix(Fill(cosh(one(elt)), 2)) @test cosh.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cosh(zero(elt)))}(2)) end + @testset "Array properties" begin + a = DiagonalMatrix(randn(elt, 2)) + @test !iszero(a) + + a = DiagonalMatrix(zeros(elt, 2)) + @test iszero(a) + + a = DiagonalMatrix(Zeros{elt}(2)) + @test iszero(a) + end @testset "LinearAlgebra matrix properties" begin @test ishermitian(DiagonalMatrix([1, 2])) @test !ishermitian(DiagonalMatrix([1, 2], (2, 3))) @@ -399,6 +420,9 @@ using LinearAlgebra: @test d isa Diagonal{eltype(v)} @test diagview(d) == diagview(a) @test diagonaltype(a) === typeof(d) + + a = randn(3, 3) + @test getdiagindices(a, 2:3) == diagview(a)[2:3] end @testset "delta" begin for (a, elt′) in ( From f86cb641430edc68b278fb4e9e8dcbb845e4430f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 15:33:33 -0400 Subject: [PATCH 12/13] More test coverage --- test/test_basics.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index f28a061..de1dce7 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,4 +1,3 @@ -using Test: @test, @testset, @test_broken, @inferred using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, @@ -19,9 +18,10 @@ using DiagonalArrays: diagview, getdiagindices using FillArrays: Fill, Ones, Zeros -using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength using LinearAlgebra: Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr +using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength +using Test: @test, @test_throws, @testset, @test_broken, @inferred @testset "Test DiagonalArrays" begin @testset "DiagonalArray (eltype=$elt)" for elt in ( @@ -118,6 +118,8 @@ using LinearAlgebra: DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2)) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) + + init = ShapeInitializer() @test DiagonalMatrix(Ones{elt}(2)) ≡ DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, Base.OneTo.((2, 2))) ≡ DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( @@ -129,6 +131,14 @@ using LinearAlgebra: init, Unstored(Zeros{elt}(2, 2)) ) + init = ShapeInitializer() + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, Base.OneTo.((2, 2))) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}( + init, Base.OneTo.((2, 2))... + ) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, (2, 2)) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, 2, 2) + # 0-dim constructors v = randn(elt, 1) @test DiagonalArray(v) ≡ From 37795bc76bf8a1e09a517eb913a700c831b98d47 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 9 Sep 2025 15:44:33 -0400 Subject: [PATCH 13/13] Add KroneckerArrays as a downstream test --- .github/workflows/IntegrationTest.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 381cd67..2e658d7 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -18,6 +18,7 @@ jobs: matrix: pkg: - 'BlockSparseArrays' + - 'KroneckerArrays' uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main" with: localregistry: "https://github.com/ITensor/ITensorRegistry.git"