Skip to content

Commit fe2c5d5

Browse files
committed
update similar_diagonal to same logic
1 parent 42e22bf commit fe2c5d5

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

src/tensors/abstracttensor.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: Abstra
586586
# The implementation is again written for similar_diagonal(t, TorA, V::ElementarySpace) -> DiagonalTensorMap
587587
# and all other methods are just filling in default arguments
588588
@doc """
589-
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V::ElementarySpace])
589+
similar_diagonal(t::AbstractTensorMap, [AorT=scalartype(t)], [V::ElementarySpace])
590590
591591
Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and
592592
structure `V ← V`, based on the source tensormap. The second argument is optional and defaults
@@ -602,21 +602,12 @@ See also [`Base.similar`](@ref).
602602

603603
# 3 arguments
604604
function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA}
605-
if TorA <: Number
606-
T = TorA
607-
A = similarstoragetype(t, T)
608-
elseif TorA <: DenseVector
609-
A = TorA
610-
T = scalartype(A)
611-
else
612-
throw(ArgumentError("Type $TorA not supported for similar"))
613-
end
614-
615-
return DiagonalTensorMap{T, spacetype(V), A}(undef, V)
605+
A = similarstoragetype(TorA <: Number ? similarstoragetype(t, TorA) : TorA)
606+
return DiagonalTensorMap{scalartype(A), spacetype(V), A}(undef, V)
616607
end
617608

618-
similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, similarstoragetype(t), _diagspace(t))
619-
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, similarstoragetype(t), V)
609+
similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, scalartype(t), _diagspace(t))
610+
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, scalartype(t), V)
620611
similar_diagonal(t::AbstractTensorMap, T::Type) = similar_diagonal(t, T, _diagspace(t))
621612

622613
function _diagspace(t)

0 commit comments

Comments
 (0)