Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/factorizations/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function MAK.initialize_output(
V_cod = fuse(codomain(t))
V_dom = fuse(domain(t))
U = similar(t, codomain(t) ← V_cod)
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom)
S = similar_diagonal(t, real(scalartype(t)), V_cod)
Vᴴ = similar(t, V_dom ← domain(t))
return U, S, Vᴴ
end
Expand Down
7 changes: 5 additions & 2 deletions src/factorizations/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ module Factorizations
export copy_oftype, factorisation_scalartype, one!, truncspace

using ..TensorKit
using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, blocktype, foreachblock, one!
using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector,
blocktype, foreachblock, one!,
similar_diagonal, similarstoragetype

using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!,
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal,
svdvals, svdvals!, eigen, eigen!,
isposdef, isposdef!, ishermitian

using TensorOperations: Index2Tuple
Expand Down
16 changes: 9 additions & 7 deletions src/factorizations/matrixalgebrakit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,45 +74,47 @@ end
function MAK.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, ::AbstractAlgorithm)
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
U = similar(t, codomain(t) ← V_cod)
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod)
S = similar_diagonal(t, real(scalartype(t)), V_cod)
Vᴴ = similar(t, V_dom ← domain(t))
return U, S, Vᴴ
end

function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
V_cod = infimum(fuse(codomain(t)), fuse(domain(t)))
T = real(scalartype(t))
return SectorVector{T}(undef, V_cod)
A = similarstoragetype(t, T)
return SectorVector{T, sectortype(t), A}(undef, V_cod)
end

# Eigenvalue decomposition
# ------------------------
function MAK.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
V_D = fuse(domain(t))
T = real(scalartype(t))
D = DiagonalTensorMap{T}(undef, V_D)
D = similar_diagonal(t, real(scalartype(t)), V_D)
V = similar(t, codomain(t) ← V_D)
return D, V
end

function MAK.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
V_D = fuse(domain(t))
Tc = complex(scalartype(t))
D = DiagonalTensorMap{Tc}(undef, V_D)
D = similar_diagonal(t, Tc, V_D)
V = similar(t, Tc, codomain(t) ← V_D)
return D, V
end

function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
V_D = fuse(domain(t))
T = real(scalartype(t))
return SectorVector{T}(undef, V_D)
A = similarstoragetype(t, T)
return SectorVector{T, sectortype(t), A}(undef, V_D)
end

function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
V_D = fuse(domain(t))
Tc = complex(scalartype(t))
return SectorVector{Tc}(undef, V_D)
A = similarstoragetype(t, Tc)
return SectorVector{Tc, sectortype(t), A}(undef, V_D)
end

# QR decomposition
Expand Down
4 changes: 2 additions & 2 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function MAK.truncate(

Ũ = similar(U, codomain(U) ← V_truncated)
truncate_domain!(Ũ, U, ind)
S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated)
S̃ = similar_diagonal(S, V_truncated)
truncate_diagonal!(S̃, S, ind)
Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ))
truncate_codomain!(Ṽᴴ, Vᴴ, ind)
Expand Down Expand Up @@ -132,7 +132,7 @@ for f! in (:eig_trunc!, :eigh_trunc!)
ind = MAK.findtruncated(diagview(D), strategy)
V_truncated = truncate_space(space(D, 1), ind)

D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated)
D̃ = similar_diagonal(D, V_truncated)
truncate_diagonal!(D̃, D, ind)

Ṽ = similar(V, codomain(V) ← V_truncated)
Expand Down
47 changes: 47 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ The structure may be specified either as a single `HomSpace` argument or as `cod

By default, this will result in `TensorMap{T}(undef, V)` when custom objects do not
specialize this method.

See also [`similar_diagonal`](@ref).
""" Base.similar(::AbstractTensorMap, args...)

function Base.similar(
Expand Down Expand Up @@ -543,6 +545,51 @@ function Base.similar(
return TensorMap{scalartype(TT)}(undef, cod, dom)
end

# similar diagonal
# ----------------
# The implementation is again written for similar_diagonal(t, TorA, V::ElementarySpace) -> DiagonalTensorMap
# and all other methods are just filling in default arguments
@doc """
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V::ElementarySpace])

Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and
structure `V ← V`, based on the source tensormap. The second argument is optional and defaults
to the given tensor's `storagetype`, while the third argument can only be omitted for square
input tensors of space `V ← V`, to conform with the diagonal structure.

By default, this will result in `DiagonalTensorMap{T}(undef, V)` when custom objects do not
specialize this method. Furthermore, the method will throw if the provided space is not compatible
with a diagonal structure.

See also [`Base.similar`](@ref).
""" similar_diagonal(::AbstractTensorMap, args...)

# 3 arguments
function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA}
if TorA <: Number
T = TorA
A = similarstoragetype(t, T)
elseif TorA <: DenseVector
A = TorA
T = scalartype(A)
else
throw(ArgumentError("Type $TorA not supported for similar"))
end

return DiagonalTensorMap{T, spacetype(V), A}(undef, V)
end

similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, similarstoragetype(t), _diagspace(t))
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, similarstoragetype(t), V)
similar_diagonal(t::AbstractTensorMap, T::Type) = similar_diagonal(t, T, _diagspace(t))

function _diagspace(t)
cod, dom = codomain(t), domain(t)
length(cod) == 1 && cod == dom ||
throw(ArgumentError("space does not support a DiagonalTensorMap"))
return only(cod)
end

# Equality and approximality
#----------------------------
function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap)
Expand Down
10 changes: 6 additions & 4 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S}
return d
end

Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number}
return DiagonalTensorMap(similar(d.data, T), d.domain)
end
Base.similar(d::DiagonalTensorMap) = similar_diagonal(d)
Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T)

similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} =
DiagonalTensorMap(similar(d.data, T), d.domain)

# TODO: more constructors needed?

Expand Down
5 changes: 5 additions & 0 deletions src/tensors/sectorvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ function SectorVector{T}(::UndefInitializer, V::ElementarySpace) where {T}
structure = diagonalblockstructure(V ← V)
return SectorVector(data, structure)
end
function SectorVector{T, I, A}(::UndefInitializer, V::ElementarySpace) where {T, I, A <: AbstractVector{T}}
data = A(undef, reduceddim(V))
structure = diagonalblockstructure(V ← V)
return SectorVector{T, I, A}(data, structure)
end

Base.parent(v::SectorVector) = v.data

Expand Down
Loading