From be5562f2685c94a06b9f4ce66bfa31f2a209ee9a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 8 Dec 2025 16:24:30 -0500 Subject: [PATCH 1/3] Use FusionStyle in tensor factorizations and functions --- Project.toml | 5 +- src/factorizations.jl | 103 +++++++++++++++++++++++++++++++---------- src/matrixfunctions.jl | 44 ++++++++++++++---- 3 files changed, 119 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index 22b73c6..c93f7ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.6.3" +version = "0.6.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -28,3 +28,6 @@ TensorOperations = "5" TupleTools = "1.6" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" + +[workspace] +projects = ["benchmark", "dev", "docs", "examples", "test"] diff --git a/src/factorizations.jl b/src/factorizations.jl index 837ffb5..e0c0d3c 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -6,14 +6,17 @@ for f in ( :factorize, ) @eval begin - function $f(A::AbstractArray, ndims_codomain::Val; kwargs...) - A_mat = matricize(A, ndims_codomain) + function $f(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(style, A, ndims_codomain) X, Y = MatrixAlgebra.$f(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) - return unmatricize(X, axes_X), unmatricize(Y, axes_Y) + return unmatricize(style, X, axes_X), unmatricize(style, Y, axes_Y) + end + function $f(A::AbstractArray, ndims_codomain::Val; kwargs...) + return $f(FusionStyle(A), A, ndims_codomain; kwargs...) end end end @@ -24,19 +27,40 @@ for f in ( ) @eval begin function $f( - A::AbstractArray, + style::FusionStyle, A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) A_perm = bipermutedims(A, perm_codomain, perm_domain) - return $f(A_perm, Val(length(perm_codomain)); kwargs...) + return $f(style, A_perm, Val(length(perm_codomain)); kwargs...) end - function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + function $f( + A::AbstractArray, + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; + kwargs..., + ) + return $f(FusionStyle(A), A, perm_codomain, perm_domain; kwargs...) + end + + function $f( + style::FusionStyle, A::AbstractArray, + labels_A, labels_codomain, labels_domain; kwargs..., + ) biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) - return $f(A, blocks(biperm)...; kwargs...) + return $f(style, A, blocks(biperm)...; kwargs...) + end + function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + return $f(FusionStyle(A), A, labels_A, labels_codomain, labels_domain; kwargs...) + end + + function $f( + style::FusionStyle, A::AbstractArray, + biperm::AbstractBlockPermutation{2}; kwargs..., + ) + return $f(style, A, blocks(biperm)...; kwargs...) end function $f(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return $f(A, blocks(biperm)...; kwargs...) + return $f(FusionStyle(A), A, biperm; kwargs...) end end end @@ -191,14 +215,20 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`, `MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`. """ -function eigen(A::AbstractArray, ndims_codomain::Val; kwargs...) +eigen + +function eigen(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) # tensor to matrix - A_mat = matricize(A, ndims_codomain) + A_mat = matricize(style, A, ndims_codomain) D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, = blocks(axes(A)[biperm]) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) - return D, unmatricize(V, axes_V) + # TODO: Make sure `D` has the same basis as `V`. + return D, unmatricize(style, V, axes_V) +end +function eigen(A::AbstractArray, ndims_codomain::Val; kwargs...) + return eigen(FusionStyle(A), A, ndims_codomain; kwargs...) end """ @@ -219,10 +249,15 @@ their labels or directly through a bi-permutation. The output is a vector of eig See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`. """ -function eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...) - A_mat = matricize(A, ndims_codomain) +eigvals + +function eigvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(style, A, ndims_codomain) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end +function eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...) + return eigvals(FusionStyle(A), A, ndims_codomain; kwargs...) +end """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ @@ -243,14 +278,19 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`. """ -function svd(A::AbstractArray, ndims_codomain::Val; kwargs...) - A_mat = matricize(A, ndims_codomain) +svd + +function svd(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(style, A, ndims_codomain) U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) - return unmatricize(U, axes_U), S, unmatricize(Vᴴ, axes_Vᴴ) + return unmatricize(style, U, axes_U), S, unmatricize(style, Vᴴ, axes_Vᴴ) +end +function svd(A::AbstractArray, ndims_codomain::Val; kwargs...) + return svd(FusionStyle(A), A, ndims_codomain; kwargs...) end """ @@ -265,10 +305,15 @@ their labels or directly through a bi-permutation. The output is a vector of sin See also `MatrixAlgebraKit.svd_vals!`. """ -function svdvals(A::AbstractArray, ndims_codomain::Val) - A_mat = matricize(A, ndims_codomain) +svdvals + +function svdvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val) + A_mat = matricize(style, A, ndims_codomain) return MatrixAlgebra.svdvals!(A_mat) end +function svdvals(A::AbstractArray, ndims_codomain::Val) + return svdvals(FusionStyle(A), A, ndims_codomain) +end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N @@ -289,13 +334,18 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function left_null(A::AbstractArray, ndims_codomain::Val; kwargs...) - A_mat = matricize(A, ndims_codomain) +left_null + +function left_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(style, A, ndims_codomain) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain = first(blocks(axes(A)[biperm])) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) - return unmatricize(N, axes_N) + return unmatricize(style, N, axes_N) +end +function left_null(A::AbstractArray, ndims_codomain::Val; kwargs...) + return left_null(FusionStyle(A), A, ndims_codomain; kwargs...) end """ @@ -317,11 +367,16 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...) - A_mat = matricize(A, ndims_codomain) +right_null + +function right_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(style, A, ndims_codomain) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_domain = last(blocks((axes(A)[biperm]))) axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) - return unmatricize(Nᴴ, axes_Nᴴ) + return unmatricize(style, Nᴴ, axes_Nᴴ) +end +function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...) + return right_null(FusionStyle(A), A, ndims_codomain; kwargs...) end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 2a4294e..1679a40 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -33,26 +33,54 @@ const MATRIX_FUNCTIONS = [ for f in MATRIX_FUNCTIONS @eval begin - function $f(a::AbstractArray, ndims_codomain::Val; kwargs...) - a_mat = matricize(a, ndims_codomain) + function $f(style::FusionStyle, a::AbstractArray, ndims_codomain::Val; kwargs...) + a_mat = matricize(style, a, ndims_codomain) fa_mat = Base.$f(a_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) - return unmatricize(fa_mat, axes(a)[biperm]) + return unmatricize(style, fa_mat, axes(a)[biperm]) + end + function $f(a::AbstractArray, ndims_codomain::Val; kwargs...) + return $f(FusionStyle(a), a, ndims_codomain; kwargs...) end + function $f( - a::AbstractArray, + style::FusionStyle, a::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) a_perm = bipermutedims(a, perm_codomain, perm_domain) - return $f(a_perm, Val(length(perm_codomain)); kwargs...) + return $f(style, a_perm, Val(length(perm_codomain)); kwargs...) + end + function $f( + a::AbstractArray, + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; + kwargs..., + ) + return $f(FusionStyle(a), a, perm_codomain, perm_domain; kwargs...) end - function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) + + function $f( + style::FusionStyle, a::AbstractArray, + labels_a, labels_codomain, labels_domain; kwargs..., + ) biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) - return $f(a, blocks(biperm)...; kwargs...) + return $f(style, a, blocks(biperm)...; kwargs...) + end + function $f( + a::AbstractArray, + labels_a, labels_codomain, labels_domain; kwargs..., + ) + return $f(FusionStyle(a), a, labels_a, labels_codomain, labels_domain; kwargs...) + end + + function $f( + style::FusionStyle, a::AbstractArray, + biperm::AbstractBlockPermutation{2}; kwargs..., + ) + return $f(style, a, blocks(biperm)...; kwargs...) end function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return $f(a, blocks(biperm)...; kwargs...) + return $f(FusionStyle(a), a, biperm; kwargs...) end end end From 184e857393d734b0d3694668861d6d8b87eb20e6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 8 Dec 2025 17:01:58 -0500 Subject: [PATCH 2/3] Change flow logic to avoid breaking FusionTensors --- src/factorizations.jl | 8 +++++--- src/matrixfunctions.jl | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index e0c0d3c..75c74a6 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -39,7 +39,8 @@ for f in ( perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) - return $f(FusionStyle(A), A, perm_codomain, perm_domain; kwargs...) + A_perm = bipermutedims(A, perm_codomain, perm_domain) + return $f(A_perm, perm_codomain, perm_domain; kwargs...) end function $f( @@ -50,7 +51,8 @@ for f in ( return $f(style, A, blocks(biperm)...; kwargs...) end function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) - return $f(FusionStyle(A), A, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return $f(A, blocks(biperm)...; kwargs...) end function $f( @@ -60,7 +62,7 @@ for f in ( return $f(style, A, blocks(biperm)...; kwargs...) end function $f(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return $f(FusionStyle(A), A, biperm; kwargs...) + return $f(A, blocks(biperm)...; kwargs...) end end end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 1679a40..d0bb8e8 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -56,7 +56,8 @@ for f in MATRIX_FUNCTIONS perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) - return $f(FusionStyle(a), a, perm_codomain, perm_domain; kwargs...) + a_perm = bipermutedims(a, perm_codomain, perm_domain) + return $f(a_perm, perm_codomain, perm_domain; kwargs...) end function $f( @@ -70,7 +71,8 @@ for f in MATRIX_FUNCTIONS a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs..., ) - return $f(FusionStyle(a), a, labels_a, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) + return $f(a, labels_a, labels_codomain, labels_domain; kwargs...) end function $f( @@ -80,7 +82,7 @@ for f in MATRIX_FUNCTIONS return $f(style, a, blocks(biperm)...; kwargs...) end function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) - return $f(FusionStyle(a), a, biperm; kwargs...) + return $f(a, blocks(biperm)...; kwargs...) end end end From cbbfea9f1d49bb4e5498a3af82cf382dced56bb6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 8 Dec 2025 17:07:11 -0500 Subject: [PATCH 3/3] Fix typo --- src/factorizations.jl | 2 +- src/matrixfunctions.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index 75c74a6..3c4b1b7 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -40,7 +40,7 @@ for f in ( kwargs..., ) A_perm = bipermutedims(A, perm_codomain, perm_domain) - return $f(A_perm, perm_codomain, perm_domain; kwargs...) + return $f(A_perm, Val(length(perm_codomain)); kwargs...) end function $f( diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index d0bb8e8..9a0e780 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -57,7 +57,7 @@ for f in MATRIX_FUNCTIONS kwargs..., ) a_perm = bipermutedims(a, perm_codomain, perm_domain) - return $f(a_perm, perm_codomain, perm_domain; kwargs...) + return $f(a_perm, Val(length(perm_codomain)); kwargs...) end function $f( @@ -72,7 +72,7 @@ for f in MATRIX_FUNCTIONS labels_a, labels_codomain, labels_domain; kwargs..., ) biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) - return $f(a, labels_a, labels_codomain, labels_domain; kwargs...) + return $f(a, blocks(biperm)...; kwargs...) end function $f(