Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/lib/spaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ the resuling `HomSpace` after applying certain tensor operations.

```@docs
flip(W::HomSpace{S}, I) where {S}
TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
TensorKit.permute(::HomSpace, ::Index2Tuple)
TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S}
insertleftunit(::HomSpace, ::Val{i}) where {i}
Expand Down
7 changes: 5 additions & 2 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,11 @@ struct SpaceMismatch{S <: Union{Nothing, AbstractString}} <: TensorException
message::S
end
SpaceMismatch() = SpaceMismatch{Nothing}(nothing)
Base.showerror(io::IO, ::SpaceMismatch{Nothing}) = print(io, "SpaceMismatch()")
Base.showerror(io::IO, e::SpaceMismatch) = print(io, "SpaceMismatch(\"", e.message, "\")")
function Base.showerror(io::IO, err::SpaceMismatch)
print(io, "SpaceMismatch: ")
isnothing(err.message) || print(io, err.message)
return nothing
end

# Exception type for all errors related to invalid tensor index specification.
struct IndexError{S <: Union{Nothing, AbstractString}} <: TensorException
Expand Down
50 changes: 44 additions & 6 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ end

spacetype(::Type{<:HomSpace{S}}) where {S} = S

numout(W::HomSpace) = length(codomain(W))
numin(W::HomSpace) = length(domain(W))
numind(W::HomSpace) = numin(W) + numout(W)

const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}}
const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{
S, ProductSpace{S, N₁},
ProductSpace{S, N₂},
}

numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁
numin(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₂

function Base.getindex(W::TensorMapSpace{<:IndexSpace, N₁, N₂}, i) where {N₁, N₂}
return i <= N₁ ? codomain(W)[i] : dual(domain(W)[i - N₁])
end
Expand Down Expand Up @@ -137,18 +136,33 @@ fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist
# Operations on HomSpaces
# -----------------------
"""
permute(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂})
permute(W::HomSpace, (p₁, p₂)::Index2Tuple)

Return the `HomSpace` obtained by permuting the indices of the domain and codomain of `W`
according to the permutation `p₁` and `p₂` respectively.
"""
function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁, N₂}) where {S, N₁, N₂}
function permute(W::HomSpace, (p₁, p₂)::Index2Tuple)
p = (p₁..., p₂...)
TupleTools.isperm(p) && length(p) == numind(W) ||
throw(ArgumentError("$((p₁, p₂)) is not a valid permutation for $(W)"))
return select(W, (p₁, p₂))
end

_transpose_indices(W::HomSpace) = (reverse(domainind(W)), reverse(codomainind(W)))

function LinearAlgebra.transpose(W::HomSpace, (p₁, p₂)::Index2Tuple = _transpose_indices(W))
p = linearizepermutation(p₁, p₂, numout(W), numin(W))
iscyclicpermutation(p) || throw(ArgumentError(lazy"$((p₁, p₂)) is not a cyclic permutation for $W"))
return select(W, (p₁, p₂))
end

function braid(W::HomSpace, (p₁, p₂)::Index2Tuple, levels::IndexTuple)
p = (p₁..., p₂...)
TupleTools.isperm(p) && length(p) == numind(W) == length(levels) ||
throw(ArgumentError("$((p₁, p₂)), $levels is not a valid braiding for $(W)"))
return select(W, (p₁, p₂))
end

"""
select(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂})

Expand Down Expand Up @@ -188,6 +202,30 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
return HomSpace(codomain(W), domain(V))
end

function TensorOperations.tensorcontract(
A::HomSpace, pA::Index2Tuple, conjA::Bool,
B::HomSpace, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple
)
return if conjA && conjB
A′ = A'
pA′ = adjointtensorindices(A, pA)
B′ = B'
pB′ = adjointtensorindices(B, pB)
TensorOperations.tensorcontract(A′, pA′, false, B′, pB′, false, pAB)
elseif conjA
A′ = A'
pA′ = adjointtensorindices(A, pA)
TensorOperations.tensorcontract(A′, pA′, false, B, pB, false, pAB)
elseif conjB
B′ = B'
pB′ = adjointtensorindices(B, pB)
TensorOperations.tensorcontract(A, pA, false, B′, pB′, false, pAB)
else
return permute(compose(permute(A, pA), permute(B, pB)), pAB)
end
end

"""
insertleftunit(W::HomSpace, i=numind(W) + 1; conj=false, dual=false)

Expand Down
70 changes: 37 additions & 33 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,80 +89,84 @@ domain(t::AbstractTensorMap) = domain(space(t))
domain(t::AbstractTensorMap, i) = domain(t)[i]
source(t::AbstractTensorMap) = domain(t) # categorical terminology

"""
numout(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int
@doc """
numout(x) -> Int
numout(T::Type) -> Int

Return the number of output spaces of a tensor. This is equivalent to the number of spaces in the codomain of that tensor.
Return the length of the codomain, i.e. the number of output spaces.
By default, this is implemented in the type domain.

See also [`numin`](@ref) and [`numind`](@ref).
"""
""" numout

numout(x) = numout(typeof(x))
numout(T::Type) = throw(MethodError(numout, T)) # avoid infinite recursion
numout(::Type{<:AbstractTensorMap{T, S, N₁}}) where {T, S, N₁} = N₁

"""
numin(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int
@doc """
numin(x) -> Int
numin(T::Type) -> Int

Return the number of input spaces of a tensor. This is equivalent to the number of spaces in the domain of that tensor.
Return the length of the domain, i.e. the number of input spaces.
By default, this is implemented in the type domain.

See also [`numout`](@ref) and [`numind`](@ref).
"""
""" numin

numin(x) = numin(typeof(x))
numin(T::Type) = throw(MethodError(numin, T)) # avoid infinite recursion
numin(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}) where {T, S, N₁, N₂} = N₂

"""
numind(::Union{T,Type{T}}) where {T<:AbstractTensorMap} -> Int
numind(x) -> Int
numind(T::Type) -> Int
order(x) = numind(x)

Return the total number of input and output spaces of a tensor. This is equivalent to the
total number of spaces in the domain and codomain of that tensor.
Return the total number of input and output spaces, i.e. `numin(x) + numout(x)`.
Alternatively, the alias `order` can also be used.

See also [`numout`](@ref) and [`numin`](@ref).
"""
numind(::Type{TT}) where {TT <: AbstractTensorMap} = numin(TT) + numout(TT)
numind(x) = numin(x) + numout(x)

const order = numind

"""
codomainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
codomainind(x) -> Tuple{Int}

Return all indices of the codomain of a tensor.
Return all indices of the codomain.

See also [`domainind`](@ref) and [`allind`](@ref).
"""
function codomainind(::Type{TT}) where {TT <: AbstractTensorMap}
return ntuple(identity, numout(TT))
end
codomainind(t::AbstractTensorMap) = codomainind(typeof(t))
codomainind(x) = ntuple(identity, numout(x))

"""
domainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
domainind(x) -> Tuple{Int}

Return all indices of the domain of a tensor.
Return all indices of the domain.

See also [`codomainind`](@ref) and [`allind`](@ref).
"""
function domainind(::Type{TT}) where {TT <: AbstractTensorMap}
return ntuple(n -> numout(TT) + n, numin(TT))
end
domainind(t::AbstractTensorMap) = domainind(typeof(t))
domainind(x) = ntuple(n -> numout(x) + n, numin(x))

"""
allind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
allind(x) -> Tuple{Int}

Return all indices of a tensor, i.e. the indices of its domain and codomain.
Return all indices, i.e. the indices of both domain and codomain.

See also [`codomainind`](@ref) and [`domainind`](@ref).
"""
function allind(::Type{TT}) where {TT <: AbstractTensorMap}
return ntuple(identity, numind(TT))
end
allind(t::AbstractTensorMap) = allind(typeof(t))
allind(x) = ntuple(identity, numind(x))

function adjointtensorindex(t::AbstractTensorMap, i)
function adjointtensorindex(t, i)
return ifelse(i <= numout(t), numin(t) + i, i - numout(t))
end

function adjointtensorindices(t::AbstractTensorMap, indices::IndexTuple)
function adjointtensorindices(t, indices::IndexTuple)
return map(i -> adjointtensorindex(t, i), indices)
end

function adjointtensorindices(t::AbstractTensorMap, p::Index2Tuple)
function adjointtensorindices(t, p::Index2Tuple)
return (adjointtensorindices(t, p[1]), adjointtensorindices(t, p[2]))
end

Expand Down
58 changes: 43 additions & 15 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permuta

See [`transpose`](@ref) for creating a new tensor and [`add_transpose!`](@ref) for a more general version.
"""
function LinearAlgebra.transpose!(
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(t)
@propagate_inbounds function LinearAlgebra.transpose!(
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(tsrc)
)
return add_transpose!(tdst, tsrc, (p₁, p₂), One(), Zero())
end
Expand Down Expand Up @@ -229,7 +229,7 @@ case of a transposition that only changes the number of in- and outgoing indices

See [`repartition`](@ref) for creating a new tensor.
"""
function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S}
@propagate_inbounds function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S}
numind(tsrc) == numind(tdst) ||
throw(ArgumentError("tsrc and tdst should have an equal amount of indices"))
all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...)
Expand Down Expand Up @@ -410,6 +410,38 @@ end
#-------------------------------------
# Full implementations based on `add`
#-------------------------------------
spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args...) =
spacecheck_transform(f, space(tdst), space(tsrc), args...)
@noinline function spacecheck_transform(f, Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple)
spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types"))
f(Vsrc, p) == Vdst ||
throw(
SpaceMismatch(
lazy"""
incompatible spaces for `$f(Vsrc, $p) -> Vdst`
Vsrc = $Vsrc
Vdst = $Vdst
"""
)
)
return nothing
end
@noinline function spacecheck_transform(::typeof(braid), Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels::IndexTuple)
spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types"))
braid(Vsrc, p, levels) == Vdst ||
throw(
SpaceMismatch(
lazy"""
incompatible spaces for `braid(Vsrc, $p, $levels) -> Vdst`
Vsrc = $Vsrc
Vdst = $Vdst
"""
)
)
return nothing
end


"""
add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple,
α::Number, β::Number, backend::AbstractBackend...)
Expand All @@ -423,8 +455,9 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
α::Number, β::Number, backend::AbstractBackend...
)
@boundscheck spacecheck_transform(permute, tdst, tsrc, p)
transformer = treepermuter(tdst, tsrc, p)
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
end

"""
Expand All @@ -440,14 +473,12 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple,
α::Number, β::Number, backend::AbstractBackend...
)
length(levels) == numind(tsrc) ||
throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc)) ← $(domain(tsrc))"))

@boundscheck spacecheck_transform(braid, tdst, tsrc, p, levels)
levels1 = TupleTools.getindices(levels, codomainind(tsrc))
levels2 = TupleTools.getindices(levels, domainind(tsrc))
# TODO: arg order for tensormaps is different than for fusiontrees
transformer = treebraider(tdst, tsrc, p, (levels1, levels2))
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
end

"""
Expand All @@ -463,19 +494,16 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
α::Number, β::Number, backend::AbstractBackend...
)
@boundscheck spacecheck_transform(transpose, tdst, tsrc, p)
transformer = treetransposer(tdst, tsrc, p)
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
end

function add_transform!(
@propagate_inbounds function add_transform!(
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, transformer,
α::Number, β::Number, backend::AbstractBackend...
)
@boundscheck begin
permute(space(tsrc), p) == space(tdst) ||
throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)),
dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p[1]), p₂ = $(p[2])"))
end
@boundscheck spacecheck_transform(permute, tdst, tsrc, p)

if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc)
add!(tdst, tsrc, α, β)
Expand Down
30 changes: 30 additions & 0 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,35 @@ function TO.tensortrace!(
end

# tensorcontract!
function spacecheck_contract(
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple
)
return spacecheck_contract(space(C), space(A), pA, conjA, space(B), pB, conjB, pAB)
end
@noinline function spacecheck_contract(
VC::TensorMapSpace,
VA::TensorMapSpace, pA::Index2Tuple, conjA::Bool,
VB::TensorMapSpace, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple
)
spacetype(VC) == spacetype(VA) == spacetype(VB) || throw(SectorMismatch("incompatible sector types"))
TO.tensorcontract(VA, pA, conjA, VB, pB, conjB, pAB) == VC ||
throw(
SpaceMismatch(
lazy"""
incompatible spaces for `tensorcontract(VA, $pA, $conjA, VB, $pB, $conjB, $pAB) -> VC`
VA = $VA
VB = $VB
VC = $VC
"""
)
)
return nothing
end

function TO.tensorcontract!(
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
Expand All @@ -98,6 +127,7 @@ function TO.tensorcontract!(
backend, allocator
)
pAB′ = _canonicalize(pAB, C)
@boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′)
if conjA && conjB
A′ = A'
pA′ = adjointtensorindices(A, pA)
Expand Down
Loading