diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index c54bef00e..a101b1f21 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -72,7 +72,7 @@ function MAK.initialize_output( V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) U = similar(t, codomain(t) ← V_cod) - S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom) + S = similar_diagonal(t, real(scalartype(t)), V_cod) Vᴴ = similar(t, V_dom ← domain(t)) return U, S, Vᴴ end diff --git a/src/factorizations/factorizations.jl b/src/factorizations/factorizations.jl index 73fbc45eb..02f1712eb 100644 --- a/src/factorizations/factorizations.jl +++ b/src/factorizations/factorizations.jl @@ -6,9 +6,12 @@ module Factorizations export copy_oftype, factorisation_scalartype, one!, truncspace using ..TensorKit -using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, blocktype, foreachblock, one! +using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, + blocktype, foreachblock, one!, + similar_diagonal, similarstoragetype -using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!, +using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, + svdvals, svdvals!, eigen, eigen!, isposdef, isposdef!, ishermitian using TensorOperations: Index2Tuple diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index 4564a8137..90776892c 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -74,7 +74,7 @@ end function MAK.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, ::AbstractAlgorithm) V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) U = similar(t, codomain(t) ← V_cod) - S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) + S = similar_diagonal(t, real(scalartype(t)), V_cod) Vᴴ = similar(t, V_dom ← domain(t)) return U, S, Vᴴ end @@ -82,15 +82,15 @@ end function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) T = real(scalartype(t)) - return SectorVector{T}(undef, V_cod) + A = similarstoragetype(t, T) + return SectorVector{T, sectortype(t), A}(undef, V_cod) end # Eigenvalue decomposition # ------------------------ function MAK.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_D = fuse(domain(t)) - T = real(scalartype(t)) - D = DiagonalTensorMap{T}(undef, V_D) + D = similar_diagonal(t, real(scalartype(t)), V_D) V = similar(t, codomain(t) ← V_D) return D, V end @@ -98,7 +98,7 @@ end function MAK.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) - D = DiagonalTensorMap{Tc}(undef, V_D) + D = similar_diagonal(t, Tc, V_D) V = similar(t, Tc, codomain(t) ← V_D) return D, V end @@ -106,13 +106,15 @@ end function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) - return SectorVector{T}(undef, V_D) + A = similarstoragetype(t, T) + return SectorVector{T, sectortype(t), A}(undef, V_D) end function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) - return SectorVector{Tc}(undef, V_D) + A = similarstoragetype(t, Tc) + return SectorVector{Tc, sectortype(t), A}(undef, V_D) end # QR decomposition diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index dd2e663c6..42d1e3248 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -67,7 +67,7 @@ function MAK.truncate( Ũ = similar(U, codomain(U) ← V_truncated) truncate_domain!(Ũ, U, ind) - S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) + S̃ = similar_diagonal(S, V_truncated) truncate_diagonal!(S̃, S, ind) Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) truncate_codomain!(Ṽᴴ, Vᴴ, ind) @@ -132,7 +132,7 @@ for f! in (:eig_trunc!, :eigh_trunc!) ind = MAK.findtruncated(diagview(D), strategy) V_truncated = truncate_space(space(D, 1), ind) - D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) + D̃ = similar_diagonal(D, V_truncated) truncate_diagonal!(D̃, D, ind) Ṽ = similar(V, codomain(V) ← V_truncated) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 7eeb3af9d..217429a4e 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -491,6 +491,8 @@ The structure may be specified either as a single `HomSpace` argument or as `cod By default, this will result in `TensorMap{T}(undef, V)` when custom objects do not specialize this method. + +See also [`similar_diagonal`](@ref). """ Base.similar(::AbstractTensorMap, args...) function Base.similar( @@ -543,6 +545,51 @@ function Base.similar( return TensorMap{scalartype(TT)}(undef, cod, dom) end +# similar diagonal +# ---------------- +# The implementation is again written for similar_diagonal(t, TorA, V::ElementarySpace) -> DiagonalTensorMap +# and all other methods are just filling in default arguments +@doc """ + similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V::ElementarySpace]) + +Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and +structure `V ← V`, based on the source tensormap. The second argument is optional and defaults +to the given tensor's `storagetype`, while the third argument can only be omitted for square +input tensors of space `V ← V`, to conform with the diagonal structure. + +By default, this will result in `DiagonalTensorMap{T}(undef, V)` when custom objects do not +specialize this method. Furthermore, the method will throw if the provided space is not compatible +with a diagonal structure. + +See also [`Base.similar`](@ref). +""" similar_diagonal(::AbstractTensorMap, args...) + +# 3 arguments +function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA} + if TorA <: Number + T = TorA + A = similarstoragetype(t, T) + elseif TorA <: DenseVector + A = TorA + T = scalartype(A) + else + throw(ArgumentError("Type $TorA not supported for similar")) + end + + return DiagonalTensorMap{T, spacetype(V), A}(undef, V) +end + +similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, similarstoragetype(t), _diagspace(t)) +similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, similarstoragetype(t), V) +similar_diagonal(t::AbstractTensorMap, T::Type) = similar_diagonal(t, T, _diagspace(t)) + +function _diagspace(t) + cod, dom = codomain(t), domain(t) + length(cod) == 1 && cod == dom || + throw(ArgumentError("space does not support a DiagonalTensorMap")) + return only(cod) +end + # Equality and approximality #---------------------------- function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 356f5dbf4..e35c2ffdd 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -78,10 +78,12 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S} return d end -Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain) -function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} - return DiagonalTensorMap(similar(d.data, T), d.domain) -end +Base.similar(d::DiagonalTensorMap) = similar_diagonal(d) +Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T) + +similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain) +similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} = + DiagonalTensorMap(similar(d.data, T), d.domain) # TODO: more constructors needed? diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index a2a081058..4c914d6de 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -16,6 +16,11 @@ function SectorVector{T}(::UndefInitializer, V::ElementarySpace) where {T} structure = diagonalblockstructure(V ← V) return SectorVector(data, structure) end +function SectorVector{T, I, A}(::UndefInitializer, V::ElementarySpace) where {T, I, A <: AbstractVector{T}} + data = A(undef, reduceddim(V)) + structure = diagonalblockstructure(V ← V) + return SectorVector{T, I, A}(data, structure) +end Base.parent(v::SectorVector) = v.data