diff --git a/docs/src/lib/spaces.md b/docs/src/lib/spaces.md index 63ecb7f01..5aa7e8fd9 100644 --- a/docs/src/lib/spaces.md +++ b/docs/src/lib/spaces.md @@ -123,7 +123,7 @@ the resuling `HomSpace` after applying certain tensor operations. ```@docs flip(W::HomSpace{S}, I) where {S} -TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂} +TensorKit.permute(::HomSpace, ::Index2Tuple) TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂} TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S} insertleftunit(::HomSpace, ::Val{i}) where {i} diff --git a/src/TensorKit.jl b/src/TensorKit.jl index c6e003450..a7342d9fb 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -178,8 +178,11 @@ struct SpaceMismatch{S <: Union{Nothing, AbstractString}} <: TensorException message::S end SpaceMismatch() = SpaceMismatch{Nothing}(nothing) -Base.showerror(io::IO, ::SpaceMismatch{Nothing}) = print(io, "SpaceMismatch()") -Base.showerror(io::IO, e::SpaceMismatch) = print(io, "SpaceMismatch(\"", e.message, "\")") +function Base.showerror(io::IO, err::SpaceMismatch) + print(io, "SpaceMismatch: ") + isnothing(err.message) || print(io, err.message) + return nothing +end # Exception type for all errors related to invalid tensor index specification. struct IndexError{S <: Union{Nothing, AbstractString}} <: TensorException diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 27df834c4..b4126401a 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -39,16 +39,15 @@ end spacetype(::Type{<:HomSpace{S}}) where {S} = S -numout(W::HomSpace) = length(codomain(W)) -numin(W::HomSpace) = length(domain(W)) -numind(W::HomSpace) = numin(W) + numout(W) - const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}} const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{ S, ProductSpace{S, N₁}, ProductSpace{S, N₂}, } +numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁ +numin(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₂ + function Base.getindex(W::TensorMapSpace{<:IndexSpace, N₁, N₂}, i) where {N₁, N₂} return i <= N₁ ? codomain(W)[i] : dual(domain(W)[i - N₁]) end @@ -137,18 +136,33 @@ fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist # Operations on HomSpaces # ----------------------- """ - permute(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂}) + permute(W::HomSpace, (p₁, p₂)::Index2Tuple) Return the `HomSpace` obtained by permuting the indices of the domain and codomain of `W` according to the permutation `p₁` and `p₂` respectively. """ -function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁, N₂}) where {S, N₁, N₂} +function permute(W::HomSpace, (p₁, p₂)::Index2Tuple) p = (p₁..., p₂...) TupleTools.isperm(p) && length(p) == numind(W) || throw(ArgumentError("$((p₁, p₂)) is not a valid permutation for $(W)")) return select(W, (p₁, p₂)) end +_transpose_indices(W::HomSpace) = (reverse(domainind(W)), reverse(codomainind(W))) + +function LinearAlgebra.transpose(W::HomSpace, (p₁, p₂)::Index2Tuple = _transpose_indices(W)) + p = linearizepermutation(p₁, p₂, numout(W), numin(W)) + iscyclicpermutation(p) || throw(ArgumentError(lazy"$((p₁, p₂)) is not a cyclic permutation for $W")) + return select(W, (p₁, p₂)) +end + +function braid(W::HomSpace, (p₁, p₂)::Index2Tuple, levels::IndexTuple) + p = (p₁..., p₂...) + TupleTools.isperm(p) && length(p) == numind(W) == length(levels) || + throw(ArgumentError("$((p₁, p₂)), $levels is not a valid braiding for $(W)")) + return select(W, (p₁, p₂)) +end + """ select(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂}) @@ -188,6 +202,30 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S} return HomSpace(codomain(W), domain(V)) end +function TensorOperations.tensorcontract( + A::HomSpace, pA::Index2Tuple, conjA::Bool, + B::HomSpace, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple + ) + return if conjA && conjB + A′ = A' + pA′ = adjointtensorindices(A, pA) + B′ = B' + pB′ = adjointtensorindices(B, pB) + TensorOperations.tensorcontract(A′, pA′, false, B′, pB′, false, pAB) + elseif conjA + A′ = A' + pA′ = adjointtensorindices(A, pA) + TensorOperations.tensorcontract(A′, pA′, false, B, pB, false, pAB) + elseif conjB + B′ = B' + pB′ = adjointtensorindices(B, pB) + TensorOperations.tensorcontract(A, pA, false, B′, pB′, false, pAB) + else + return permute(compose(permute(A, pA), permute(B, pB)), pAB) + end +end + """ insertleftunit(W::HomSpace, i=numind(W) + 1; conj=false, dual=false) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 393bd905e..4e28000ed 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -89,80 +89,84 @@ domain(t::AbstractTensorMap) = domain(space(t)) domain(t::AbstractTensorMap, i) = domain(t)[i] source(t::AbstractTensorMap) = domain(t) # categorical terminology -""" - numout(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int +@doc """ + numout(x) -> Int + numout(T::Type) -> Int -Return the number of output spaces of a tensor. This is equivalent to the number of spaces in the codomain of that tensor. +Return the length of the codomain, i.e. the number of output spaces. +By default, this is implemented in the type domain. See also [`numin`](@ref) and [`numind`](@ref). -""" +""" numout + +numout(x) = numout(typeof(x)) +numout(T::Type) = throw(MethodError(numout, T)) # avoid infinite recursion numout(::Type{<:AbstractTensorMap{T, S, N₁}}) where {T, S, N₁} = N₁ -""" - numin(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int +@doc """ + numin(x) -> Int + numin(T::Type) -> Int -Return the number of input spaces of a tensor. This is equivalent to the number of spaces in the domain of that tensor. +Return the length of the domain, i.e. the number of input spaces. +By default, this is implemented in the type domain. See also [`numout`](@ref) and [`numind`](@ref). -""" +""" numin + +numin(x) = numin(typeof(x)) +numin(T::Type) = throw(MethodError(numin, T)) # avoid infinite recursion numin(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}) where {T, S, N₁, N₂} = N₂ """ - numind(::Union{T,Type{T}}) where {T<:AbstractTensorMap} -> Int + numind(x) -> Int + numind(T::Type) -> Int + order(x) = numind(x) -Return the total number of input and output spaces of a tensor. This is equivalent to the -total number of spaces in the domain and codomain of that tensor. +Return the total number of input and output spaces, i.e. `numin(x) + numout(x)`. +Alternatively, the alias `order` can also be used. See also [`numout`](@ref) and [`numin`](@ref). """ -numind(::Type{TT}) where {TT <: AbstractTensorMap} = numin(TT) + numout(TT) +numind(x) = numin(x) + numout(x) + const order = numind """ - codomainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int} + codomainind(x) -> Tuple{Int} -Return all indices of the codomain of a tensor. +Return all indices of the codomain. See also [`domainind`](@ref) and [`allind`](@ref). """ -function codomainind(::Type{TT}) where {TT <: AbstractTensorMap} - return ntuple(identity, numout(TT)) -end -codomainind(t::AbstractTensorMap) = codomainind(typeof(t)) +codomainind(x) = ntuple(identity, numout(x)) """ - domainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int} + domainind(x) -> Tuple{Int} -Return all indices of the domain of a tensor. +Return all indices of the domain. See also [`codomainind`](@ref) and [`allind`](@ref). """ -function domainind(::Type{TT}) where {TT <: AbstractTensorMap} - return ntuple(n -> numout(TT) + n, numin(TT)) -end -domainind(t::AbstractTensorMap) = domainind(typeof(t)) +domainind(x) = ntuple(n -> numout(x) + n, numin(x)) """ - allind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int} + allind(x) -> Tuple{Int} -Return all indices of a tensor, i.e. the indices of its domain and codomain. +Return all indices, i.e. the indices of both domain and codomain. See also [`codomainind`](@ref) and [`domainind`](@ref). """ -function allind(::Type{TT}) where {TT <: AbstractTensorMap} - return ntuple(identity, numind(TT)) -end -allind(t::AbstractTensorMap) = allind(typeof(t)) +allind(x) = ntuple(identity, numind(x)) -function adjointtensorindex(t::AbstractTensorMap, i) +function adjointtensorindex(t, i) return ifelse(i <= numout(t), numin(t) + i, i - numout(t)) end -function adjointtensorindices(t::AbstractTensorMap, indices::IndexTuple) +function adjointtensorindices(t, indices::IndexTuple) return map(i -> adjointtensorindex(t, i), indices) end -function adjointtensorindices(t::AbstractTensorMap, p::Index2Tuple) +function adjointtensorindices(t, p::Index2Tuple) return (adjointtensorindices(t, p[1]), adjointtensorindices(t, p[2])) end diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 408de09f0..3cfe08a61 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -175,8 +175,8 @@ the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permuta See [`transpose`](@ref) for creating a new tensor and [`add_transpose!`](@ref) for a more general version. """ -function LinearAlgebra.transpose!( - tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(t) +@propagate_inbounds function LinearAlgebra.transpose!( + tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(tsrc) ) return add_transpose!(tdst, tsrc, (p₁, p₂), One(), Zero()) end @@ -229,7 +229,7 @@ case of a transposition that only changes the number of in- and outgoing indices See [`repartition`](@ref) for creating a new tensor. """ -function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S} +@propagate_inbounds function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S} numind(tsrc) == numind(tdst) || throw(ArgumentError("tsrc and tdst should have an equal amount of indices")) all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...) @@ -410,6 +410,38 @@ end #------------------------------------- # Full implementations based on `add` #------------------------------------- +spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args...) = + spacecheck_transform(f, space(tdst), space(tsrc), args...) +@noinline function spacecheck_transform(f, Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple) + spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types")) + f(Vsrc, p) == Vdst || + throw( + SpaceMismatch( + lazy""" + incompatible spaces for `$f(Vsrc, $p) -> Vdst` + Vsrc = $Vsrc + Vdst = $Vdst + """ + ) + ) + return nothing +end +@noinline function spacecheck_transform(::typeof(braid), Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels::IndexTuple) + spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types")) + braid(Vsrc, p, levels) == Vdst || + throw( + SpaceMismatch( + lazy""" + incompatible spaces for `braid(Vsrc, $p, $levels) -> Vdst` + Vsrc = $Vsrc + Vdst = $Vdst + """ + ) + ) + return nothing +end + + """ add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple, α::Number, β::Number, backend::AbstractBackend...) @@ -423,8 +455,9 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend... ) + @boundscheck spacecheck_transform(permute, tdst, tsrc, p) transformer = treepermuter(tdst, tsrc, p) - return add_transform!(tdst, tsrc, p, transformer, α, β, backend...) + return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) end """ @@ -440,14 +473,12 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend... ) - length(levels) == numind(tsrc) || - throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc)) ← $(domain(tsrc))")) - + @boundscheck spacecheck_transform(braid, tdst, tsrc, p, levels) levels1 = TupleTools.getindices(levels, codomainind(tsrc)) levels2 = TupleTools.getindices(levels, domainind(tsrc)) # TODO: arg order for tensormaps is different than for fusiontrees transformer = treebraider(tdst, tsrc, p, (levels1, levels2)) - return add_transform!(tdst, tsrc, p, transformer, α, β, backend...) + return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) end """ @@ -463,19 +494,16 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend... ) + @boundscheck spacecheck_transform(transpose, tdst, tsrc, p) transformer = treetransposer(tdst, tsrc, p) - return add_transform!(tdst, tsrc, p, transformer, α, β, backend...) + return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...) end -function add_transform!( +@propagate_inbounds function add_transform!( tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, transformer, α::Number, β::Number, backend::AbstractBackend... ) - @boundscheck begin - permute(space(tsrc), p) == space(tdst) || - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p[1]), p₂ = $(p[2])")) - end + @boundscheck spacecheck_transform(permute, tdst, tsrc, p) if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc) add!(tdst, tsrc, α, β) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 7bd0ab96b..d22fa45d2 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -90,6 +90,35 @@ function TO.tensortrace!( end # tensorcontract! +function spacecheck_contract( + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool, + B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple + ) + return spacecheck_contract(space(C), space(A), pA, conjA, space(B), pB, conjB, pAB) +end +@noinline function spacecheck_contract( + VC::TensorMapSpace, + VA::TensorMapSpace, pA::Index2Tuple, conjA::Bool, + VB::TensorMapSpace, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple + ) + spacetype(VC) == spacetype(VA) == spacetype(VB) || throw(SectorMismatch("incompatible sector types")) + TO.tensorcontract(VA, pA, conjA, VB, pB, conjB, pAB) == VC || + throw( + SpaceMismatch( + lazy""" + incompatible spaces for `tensorcontract(VA, $pA, $conjA, VB, $pB, $conjB, $pAB) -> VC` + VA = $VA + VB = $VB + VC = $VC + """ + ) + ) + return nothing +end + function TO.tensorcontract!( C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool, @@ -98,6 +127,7 @@ function TO.tensorcontract!( backend, allocator ) pAB′ = _canonicalize(pAB, C) + @boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′) if conjA && conjB A′ = A' pA′ = adjointtensorindices(A, pA)