diff --git a/docs/src/lib/spaces.md b/docs/src/lib/spaces.md index 2a2b5696d..554e40c43 100644 --- a/docs/src/lib/spaces.md +++ b/docs/src/lib/spaces.md @@ -90,6 +90,7 @@ dual conj flip ⊕ +zero(::ElementarySpace) oneunit supremum infimum diff --git a/src/spaces/cartesianspace.jl b/src/spaces/cartesianspace.jl index c790b9ba6..f29ed263d 100644 --- a/src/spaces/cartesianspace.jl +++ b/src/spaces/cartesianspace.jl @@ -48,6 +48,7 @@ sectors(V::CartesianSpace) = OneOrNoneIterator(dim(V) != 0, Trivial()) sectortype(::Type{CartesianSpace}) = Trivial Base.oneunit(::Type{CartesianSpace}) = CartesianSpace(1) +Base.zero(::Type{CartesianSpace}) = CartesianSpace(0) ⊕(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d + V₂.d) fuse(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d * V₂.d) flip(V::CartesianSpace) = V diff --git a/src/spaces/complexspace.jl b/src/spaces/complexspace.jl index 96f99492f..ff05888b8 100644 --- a/src/spaces/complexspace.jl +++ b/src/spaces/complexspace.jl @@ -49,6 +49,7 @@ sectortype(::Type{ComplexSpace}) = Trivial Base.conj(V::ComplexSpace) = ComplexSpace(dim(V), !isdual(V)) Base.oneunit(::Type{ComplexSpace}) = ComplexSpace(1) +Base.zero(::Type{ComplexSpace}) = ComplexSpace(0) function ⊕(V₁::ComplexSpace, V₂::ComplexSpace) return isdual(V₁) == isdual(V₂) ? ComplexSpace(dim(V₁) + dim(V₂), isdual(V₁)) : diff --git a/src/spaces/generalspace.jl b/src/spaces/generalspace.jl index 4468db443..c72b55e70 100644 --- a/src/spaces/generalspace.jl +++ b/src/spaces/generalspace.jl @@ -35,6 +35,9 @@ sectortype(::Type{<:GeneralSpace}) = Trivial field(::Type{GeneralSpace{𝔽}}) where {𝔽} = 𝔽 InnerProductStyle(::Type{<:GeneralSpace}) = NoInnerProduct() +Base.oneunit(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(1, false, false) +Base.zero(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(0, false, false) + dual(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), !isdual(V), isconj(V)) Base.conj(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), isdual(V), !isconj(V)) diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index 23c380fda..00f97c962 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -132,6 +132,7 @@ function Base.axes(V::GradedSpace{I}, c::I) where {I<:Sector} end Base.oneunit(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 1) +Base.zero(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 0) # TODO: the following methods can probably be implemented more efficiently for # `FiniteGradedSpace`, but we don't expect them to be used often in hot loops, so diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 1468892e5..e2f39007f 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -128,6 +128,14 @@ that this is different from `one(V::S)`, which returns the empty product space """ Base.oneunit(V::ElementarySpace) = oneunit(typeof(V)) +""" + zero(V::S) where {S<:ElementarySpace} -> S + +Return the corresponding vector space of type `S` that represents the zero-dimensional or empty space. +This is, with a slight abuse of notation, the zero element of the direct sum of vector spaces. +""" +Base.zero(V::ElementarySpace) = zero(typeof(V)) + """ ⊕(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S oplus(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index ebc1788fc..8c86f81ea 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -497,21 +497,13 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap) else cod = codomain(t) dom = domain(t) - local A + T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) : + sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t)) + A = zeros(T, dims(cod)..., dims(dom)...) for (f₁, f₂) in fusiontrees(t) F = convert(Array, (f₁, f₂)) - if !(@isdefined A) - if eltype(F) <: Complex - T = complex(float(scalartype(t))) - elseif eltype(F) <: Integer - T = scalartype(t) - else - T = float(scalartype(t)) - end - A = fill(zero(T), (dims(cod)..., dims(dom)...)) - end Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...] - axpy!(1, StridedView(_kron(convert(Array, t[f₁, f₂]), F)), Aslice) + add!(Aslice, StridedView(_kron(convert(Array, t[f₁, f₂]), F))) end return A end diff --git a/test/bugfixes.jl b/test/bugfixes.jl index d539c01e0..e1b2df1e5 100644 --- a/test/bugfixes.jl +++ b/test/bugfixes.jl @@ -22,4 +22,11 @@ @test w == v @test scalartype(w) == Float64 end + + # https://github.com/Jutho/TensorKit.jl/issues/178 + @testset "Issue #178" begin + t = rand(U1Space(1 => 1) ← U1Space(1 => 1)') + a = convert(Array, t) + @test a == zeros(size(a)) + end end diff --git a/test/spaces.jl b/test/spaces.jl index c3890db0c..57592ca35 100644 --- a/test/spaces.jl +++ b/test/spaces.jl @@ -66,12 +66,14 @@ println("------------------------------------") @test length(sectors(V)) == 1 @test @constinferred(TensorKit.hassector(V, Trivial())) @test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial())) - @test dim(@constinferred(typeof(V)())) == 0 - @test (sectors(typeof(V)())...,) == () + @test dim(@constinferred(zero(V))) == 0 + @test (sectors(zero(V))...,) == () @test @constinferred(TensorKit.axes(V)) == Base.OneTo(d) @test ℝ^d == ℝ[](d) == CartesianSpace(d) == typeof(V)(d) W = @constinferred ℝ^1 @test @constinferred(oneunit(V)) == W == oneunit(typeof(V)) + @test @constinferred(zero(V)) == ℝ^0 == zero(typeof(V)) + @test @constinferred(⊕(V, zero(V))) == V @test @constinferred(⊕(V, V)) == ℝ^(2d) @test @constinferred(⊕(V, oneunit(V))) == ℝ^(d + 1) @test @constinferred(⊕(V, V, V, V)) == ℝ^(4d) @@ -111,12 +113,14 @@ println("------------------------------------") @test length(sectors(V)) == 1 @test @constinferred(TensorKit.hassector(V, Trivial())) @test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial())) - @test dim(@constinferred(typeof(V)())) == 0 - @test (sectors(typeof(V)())...,) == () + @test dim(@constinferred(zero(V))) == 0 + @test (sectors(zero(V))...,) == () @test @constinferred(TensorKit.axes(V)) == Base.OneTo(d) @test ℂ^d == Vect[Trivial](d) == Vect[](Trivial() => d) == ℂ[](d) == typeof(V)(d) W = @constinferred ℂ^1 @test @constinferred(oneunit(V)) == W == oneunit(typeof(V)) + @test @constinferred(zero(V)) == ℂ^0 == zero(typeof(V)) + @test @constinferred(⊕(V, zero(V))) == V @test @constinferred(⊕(V, V)) == ℂ^(2d) @test_throws SpaceMismatch (⊕(V, V')) # promote_except = ErrorException("promotion of types $(typeof(ℝ^d)) and " * @@ -200,11 +204,12 @@ println("------------------------------------") @test eval(Meta.parse(sprint(show, V))) == V @test eval(Meta.parse(sprint(show, typeof(V)))) == typeof(V) # space with no sectors - @test dim(@constinferred(typeof(V)())) == 0 + @test dim(@constinferred(zero(V))) == 0 # space with a single sector W = @constinferred GradedSpace(one(I) => 1) @test W == GradedSpace(one(I) => 1, randsector(I) => 0) @test @constinferred(oneunit(V)) == W == oneunit(typeof(V)) + @test @constinferred(zero(V)) == GradedSpace(one(I) => 0) # randsector never returns trivial sector, so this cannot error @test_throws ArgumentError GradedSpace(one(I) => 1, randsector(I) => 0, one(I) => 3) @test eval(Meta.parse(sprint(show, W))) == W @@ -226,6 +231,7 @@ println("------------------------------------") if hasfusiontensor(I) @test @constinferred(TensorKit.axes(V)) == Base.OneTo(dim(V)) end + @test @constinferred(⊕(V, zero(V))) == V @test @constinferred(⊕(V, V)) == Vect[I](c => 2dim(V, c) for c in sectors(V)) @test @constinferred(⊕(V, V, V, V)) == Vect[I](c => 4dim(V, c) for c in sectors(V)) @test @constinferred(⊕(V, oneunit(V))) == diff --git a/test/tensors.jl b/test/tensors.jl index 99d73b7c8..545ac6cc6 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -126,6 +126,11 @@ for V in spacelist @test t === @constinferred TensorMap(t.data, W) end end + for T in (Int, Float32, ComplexF64) + t = randn(T, V1 ⊗ V2 ← zero(V1)) + a = convert(Array, t) + @test norm(a) == 0 + end end end @timedtestset "Basic linear algebra" begin @@ -466,7 +471,7 @@ for V in spacelist end end @testset "empty tensor" begin - t = randn(T, V1 ⊗ V2, typeof(V1)()) + t = randn(T, V1 ⊗ V2, zero(V1)) @testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), TensorKit.QLpos(),