diff --git a/Project.toml b/Project.toml index b05c4cc..63c7d4c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.5.3" +version = "0.5.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl index cc05324..050f223 100644 --- a/ext/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -1,6 +1,6 @@ module TensorAlgebraTensorOperationsExt -using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm, blocklengths +using TensorAlgebra: TensorAlgebra, BlockedPermutation, ContractAlgorithm, blocklengths using TupleTools: TupleTools using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index2Tuple @@ -9,11 +9,13 @@ using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index Wrapper type for making a TensorOperations backend work as a TensorAlgebra algorithm. """ -struct TensorOperationsAlgorithm{B <: AbstractBackend} <: Algorithm +struct TensorOperationsAlgorithm{B <: AbstractBackend} <: ContractAlgorithm backend::B end -TensorAlgebra.Algorithm(backend::AbstractBackend) = TensorOperationsAlgorithm(backend) +function TensorAlgebra.ContractAlgorithm(backend::AbstractBackend) + return TensorOperationsAlgorithm(backend) +end trivtuple(n) = ntuple(identity, n) @@ -111,7 +113,7 @@ function TensorOperations.tensorcontract!( pAB::Index2Tuple, α::Number, β::Number, - backend::Algorithm, + backend::ContractAlgorithm, allocator, ) bipermA = _blockedpermutation(pA) @@ -131,7 +133,7 @@ function TensorOperations.tensortrace!( conjA::Bool, α::Number, β::Number, - ::Algorithm, + ::ContractAlgorithm, allocator, ) return TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, DefaultBackend(), allocator) @@ -143,7 +145,7 @@ function TensorOperations.tensoradd!( conjA::Bool, α::Number, β::Number, - ::Algorithm, + ::ContractAlgorithm, allocator, ) return TensorOperations.tensoradd!(C, A, pA, conjA, α, β, DefaultBackend(), allocator) diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 790c1d9..76baafa 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -17,7 +17,7 @@ length_domain(t) = 0 length_codomain(t) = length(t) - length_domain(t) function blockedperms( - f::typeof(contract), alg::Algorithm, dimnames_dest, dimnames1, dimnames2 + f::typeof(contract), alg::ContractAlgorithm, dimnames_dest, dimnames1, dimnames2 ) return blockedperms(f, dimnames_dest, dimnames1, dimnames2) end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index a943628..a77eecc 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -1,62 +1,47 @@ # TODO: Add `contract!!` definitions as pass-throughs to `contract!`. -abstract type Algorithm end +abstract type ContractAlgorithm end -Algorithm(alg::Algorithm) = alg +ContractAlgorithm(alg::ContractAlgorithm) = alg -struct Matricize{Style} <: Algorithm +struct Matricize{Style} <: ContractAlgorithm fusion_style::Style end Matricize() = Matricize(ReshapeFusion()) -function default_contract_alg(a1::AbstractArray, labels1, a2::AbstractArray, labels2) - style1 = FusionStyle(a1) - style2 = FusionStyle(a2) +function default_contract_algorithm(A1::Type{<:AbstractArray}, A2::Type{<:AbstractArray}) + style1 = FusionStyle(A1) + style2 = FusionStyle(A2) style1 == style2 || error("Styles must match.") return Matricize(style1) end -function default_contractadd!_alg( - a_dest::AbstractArray, labels_dest, - a1::AbstractArray, labels1, - a2::AbstractArray, labels2, - α::Number, β::Number, - ) - style_dest = FusionStyle(a_dest) - style1 = FusionStyle(a1) - style2 = FusionStyle(a2) - style_dest == style1 == style2 || error("Styles must match.") - return Matricize(style_dest) -end # Required interface if not using # matricized contraction. function contractadd!( - alg::Algorithm, - a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation, - a1::AbstractArray, - biperm1::AbstractBlockPermutation, - a2::AbstractArray, - biperm2::AbstractBlockPermutation, - α::Number, - β::Number, + alg::ContractAlgorithm, + a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, + α::Number, β::Number, ) return error("Not implemented") end function contract( - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; - alg = default_contract_alg(a1, labels1, a2, labels2), + a1::AbstractArray, labels1, + a2::AbstractArray, labels2; + alg = default_contract_algorithm(typeof(a1), typeof(a2)), kwargs..., ) - return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...) + return contract(ContractAlgorithm(alg), a1, labels1, a2, labels2; kwargs...) end function contract( - alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs... + alg::ContractAlgorithm, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2; + kwargs..., ) labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...) return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest @@ -68,49 +53,40 @@ function contract( labels1, a2::AbstractArray, labels2; - alg = default_contract_alg(a1, labels1, a2, labels2), + alg = default_contract_algorithm(typeof(a1), typeof(a2)), kwargs..., ) - return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) + return contract(ContractAlgorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) end function contract!( - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; + a_dest::AbstractArray, labels_dest, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2; kwargs..., ) return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) end function contractadd!( - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number, - β::Number; - alg = default_contractadd!_alg(a_dest, labels_dest, a1, labels1, a2, labels2, α, β), + a_dest::AbstractArray, labels_dest, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2, + α::Number, β::Number; + alg = default_contract_algorithm(typeof(a1), typeof(a2)), kwargs..., ) contractadd!( - Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs... + ContractAlgorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs... ) return a_dest end function contract( - alg::Algorithm, + alg::ContractAlgorithm, labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; + a1::AbstractArray, labels1, + a2::AbstractArray, labels2; kwargs..., ) check_input(contract, a1, labels1, a2, labels2) @@ -119,13 +95,10 @@ function contract( end function contract!( - alg::Algorithm, - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2; + alg::ContractAlgorithm, + a_dest::AbstractArray, labels_dest, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2; kwargs..., ) return contractadd!( @@ -134,15 +107,11 @@ function contract!( end function contractadd!( - alg::Algorithm, - a_dest::AbstractArray, - labels_dest, - a1::AbstractArray, - labels1, - a2::AbstractArray, - labels2, - α::Number, - β::Number; + alg::ContractAlgorithm, + a_dest::AbstractArray, labels_dest, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2, + α::Number, β::Number; kwargs..., ) check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2) @@ -151,12 +120,10 @@ function contractadd!( end function contract( - alg::Algorithm, - biperm_dest::AbstractBlockPermutation, - a1::AbstractArray, - biperm1::AbstractBlockPermutation, - a2::AbstractArray, - biperm2::AbstractBlockPermutation; + alg::ContractAlgorithm, + biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, biperm2::AbstractBlockPermutation{2}; kwargs..., ) check_input(contract, a1, biperm1, a2, biperm2) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 752ef7c..9ed4223 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -2,14 +2,20 @@ using LinearAlgebra: mul! function contractadd!( alg::Matricize, - a_dest::AbstractArray, - biperm_dest::AbstractBlockPermutation{2}, - a1::AbstractArray, - biperm1::AbstractBlockPermutation{2}, - a2::AbstractArray, - biperm2::AbstractBlockPermutation{2}, - α::Number, - β::Number, + a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, + α::Number, β::Number, + ) + return contractadd!_matricize(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β) +end + +function contractadd!_matricize( + alg::Matricize, + a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, + a1::AbstractArray, biperm1::AbstractBlockPermutation{2}, + a2::AbstractArray, biperm2::AbstractBlockPermutation{2}, + α::Number, β::Number, ) invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1)) check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2) diff --git a/src/contract/output_labels.jl b/src/contract/output_labels.jl index 071869d..3bbc210 100644 --- a/src/contract/output_labels.jl +++ b/src/contract/output_labels.jl @@ -1,6 +1,6 @@ function output_labels( f::typeof(contract), - alg::Algorithm, + alg::ContractAlgorithm, a1::AbstractArray, labels1, a2::AbstractArray, @@ -9,7 +9,7 @@ function output_labels( return output_labels(f, alg, labels1, labels2) end -function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2) +function output_labels(f::typeof(contract), ::ContractAlgorithm, labels1, labels2) return output_labels(f, labels1, labels2) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 2b408ae..b624bbf 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,8 +5,8 @@ using StableRNGs: StableRNG using TensorOperations: TensorOperations using TensorAlgebra: - Algorithm, BlockedTuple, + ContractAlgorithm, blockedpermvcat, contract, contract!, @@ -159,7 +159,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,))) end - alg_tensoroperations = Algorithm(TensorOperations.StridedBLAS()) + alg_tensoroperations = ContractAlgorithm(TensorOperations.StridedBLAS()) @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts elt_dest = promote_type(elt1, elt2) a1 = ones(elt1, (1, 1))