From 38951f52c04baa8c77e681e01e87f98a2660520a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 12:44:30 -0500 Subject: [PATCH 1/7] matricize based on just the codomain length --- Project.toml | 2 +- src/blockedpermutation.jl | 5 ++ src/factorizations.jl | 142 +++++++++++++---------------------- src/matricize.jl | 123 ++++++++++++++++-------------- src/matrixfunctions.jl | 12 +-- test/test_factorizations.jl | 2 +- test/test_matrixfunctions.jl | 2 +- 7 files changed, 132 insertions(+), 156 deletions(-) diff --git a/Project.toml b/Project.toml index 63c7d4c..311e276 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.5.4" +version = "0.5.5" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index dd942dc..1a2c74b 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -206,6 +206,11 @@ function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}}) return blockedtrivialperm(Val.(blocklengths)) end +function trivialbiperm(length_codomain::Val, length::Val) + length_domain = Val(unval(length) - unval(length_codomain)) + return blockedtrivialperm((length_codomain, length_domain)) +end + function trivialperm(blockedperm::AbstractBlockTuple) return blockedtrivialperm(blocklengths(blockedperm)) end diff --git a/src/factorizations.jl b/src/factorizations.jl index db7862d..c232a97 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -2,22 +2,14 @@ using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit: MatrixAlgebraKit for f in ( - :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, + :factorize, ) @eval begin - function $f( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) - # tensor to matrix - A_mat = matricize(A, codomain_length, domain_length) - - # factorization + function $f(A::AbstractArray, length_codomain::Val; kwargs...) + A_mat = matricize(A, length_codomain) X, Y = MatrixAlgebra.$f(A_mat; kwargs...) - - # matrix to tensor - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) @@ -27,17 +19,17 @@ for f in ( end for f in ( - :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, - :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null, + :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, + :factorize, :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null, ) @eval begin function $f( A::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) - A_perm = bipermutedims(A, codomain_perm, domain_perm) - return $f(A_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...) + A_perm = bipermutedims(A, perm_codomain, perm_domain) + return $f(A_perm, Val(length(perm_codomain)); kwargs...) end function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) @@ -51,8 +43,8 @@ end """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R - qr(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Q, R - qr(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Q, R + qr(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> Q, R + qr(A::AbstractArray, length_codomain::Val; kwargs...) -> Q, R qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as @@ -71,8 +63,8 @@ qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q - lq(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> L, Q - lq(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> L, Q + lq(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> L, Q + lq(A::AbstractArray, length_codomain::Val; kwargs...) -> L, Q lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as @@ -91,8 +83,8 @@ lq """ left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P - left_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> W, P - left_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> W, P + left_polar(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> W, P + left_polar(A::AbstractArray, length_codomain::Val; kwargs...) -> W, P left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as @@ -109,8 +101,8 @@ left_polar """ right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W - right_polar(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> P, W - right_polar(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> P, W + right_polar(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> P, W + right_polar(A::AbstractArray, length_codomain::Val; kwargs...) -> P, W right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as @@ -127,8 +119,8 @@ right_polar """ left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C - left_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> V, C - left_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> V, C + left_orth(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> V, C + left_orth(A::AbstractArray, length_codomain::Val; kwargs...) -> V, C left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -145,8 +137,8 @@ left_orth """ right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V - right_orth(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> C, V - right_orth(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> C, V + right_orth(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> C, V + right_orth(A::AbstractArray, length_codomain::Val; kwargs...) -> C, V right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -163,8 +155,8 @@ right_orth """ factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y - factorize(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> X, Y - factorize(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> X, Y + factorize(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> X, Y + factorize(A::AbstractArray, length_codomain::Val; kwargs...) -> X, Y factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as @@ -181,8 +173,8 @@ factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V - eigen(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D, V - eigen(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D, V + eigen(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> D, V + eigen(A::AbstractArray, length_codomain::Val; kwargs...) -> D, V eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as @@ -199,18 +191,11 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`, `MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`. """ -function eigen( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) +function eigen(A::AbstractArray, length_codomain::Val; kwargs...) # tensor to matrix - A_mat = matricize(A, codomain_length, domain_length) - # factorization + A_mat = matricize(A, length_codomain) D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) - - # matrix to tensor - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(A))) axes_codomain, = blocks(axes(A)[biperm]) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, unmatricize(V, axes_V) @@ -218,8 +203,8 @@ end """ eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D - eigvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> D - eigvals(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> D + eigvals(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> D + eigvals(A::AbstractArray, length_codomain::Val; kwargs...) -> D eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as @@ -234,19 +219,15 @@ their labels or directly through a bi-permutation. The output is a vector of eig See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`. """ -function eigvals( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) - A_mat = matricize(A, codomain_length, domain_length) +function eigvals(A::AbstractArray, length_codomain::Val; kwargs...) + A_mat = matricize(A, length_codomain) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ - svd(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> U, S, Vᴴ - svd(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, length_codomain::Val; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as @@ -262,18 +243,10 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`. """ -function svd( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) - # tensor to matrix - A_mat = matricize(A, codomain_length, domain_length) - # factorization +function svd(A::AbstractArray, length_codomain::Val; kwargs...) + A_mat = matricize(A, length_codomain) U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) - - # matrix to tensor - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) @@ -282,8 +255,8 @@ end """ svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S - svdvals(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}) -> S - svdvals(A::AbstractArray, codomain_length::Val, domain_length::Val) -> S + svdvals(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}) -> S + svdvals(A::AbstractArray, length_codomain::Val) -> S svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as @@ -292,18 +265,15 @@ their labels or directly through a bi-permutation. The output is a vector of sin See also `MatrixAlgebraKit.svd_vals!`. """ -function svdvals( - A::AbstractArray, - codomain_length::Val, domain_length::Val - ) - A_mat = matricize(A, codomain_length, domain_length) +function svdvals(A::AbstractArray, length_codomain::Val) + A_mat = matricize(A, length_codomain) return MatrixAlgebra.svdvals!(A_mat) end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N - left_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> N - left_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> N + left_null(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> N + left_null(A::AbstractArray, length_codomain::Val; kwargs...) -> N left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as @@ -319,14 +289,10 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function left_null( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) - A_mat = matricize(A, codomain_length, domain_length) +function left_null(A::AbstractArray, length_codomain::Val; kwargs...) + A_mat = matricize(A, length_codomain) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(A))) axes_codomain = first(blocks(axes(A)[biperm])) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) return unmatricize(N, axes_N) @@ -334,8 +300,8 @@ end """ right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ - right_null(A::AbstractArray, codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; kwargs...) -> Nᴴ - right_null(A::AbstractArray, codomain_length::Val, domain_length::Val; kwargs...) -> Nᴴ + right_null(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> Nᴴ + right_null(A::AbstractArray, length_codomain::Val::Val; kwargs...) -> Nᴴ right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as @@ -351,14 +317,10 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function right_null( - A::AbstractArray, - codomain_length::Val, domain_length::Val; - kwargs..., - ) - A_mat = matricize(A, codomain_length, domain_length) +function right_null(A::AbstractArray, length_codomain::Val; kwargs...) + A_mat = matricize(A, length_codomain) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(A))) axes_domain = last(blocks((axes(A)[biperm]))) axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) return unmatricize(Nᴴ, axes_Nᴴ) diff --git a/src/matricize.jl b/src/matricize.jl index 2e29b4d..52661d0 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -12,30 +12,31 @@ FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) # ======================================= misc ======================================== -trivial_axis(::Tuple{}) = Base.OneTo(1) -trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) -trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1]) +axis_type(a::AbstractArray) = eltype(axes(a)) +axis_type(a::AbstractArray{<:Any, 0}) = Base.OneTo{Int} +trivial_axis(::Type{<:AbstractUnitRange}) = Base.OneTo(1) +trivial_axis(::Type{<:AbstractBlockedUnitRange}) = blockedrange([1]) -# Inner version takes a list of sub-permutations, overload this one if needed. -function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, lengths::Val... - ) - axesblocks = blocks(axes[blockedtrivialperm(lengths)]) - return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) -end +# Fallback to `TensorProducts.⊗`. +# TODO: Remove dependency on TensorProducts.jl, update downstream packages: +# BlockSparseArrays.jl, GradedArrays.jl, FusionTensors.jl. +tensor_product(::FusionStyle, ax1, ax2) = ax1 ⊗ ax2 # Inner version takes a list of sub-permutations, overload this one if needed. -function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, permblocks::Tuple{Vararg{Int}}... - ) - axes′ = map(d -> axes[d], permmortar(permblocks)) - return fuseaxes(axes′, Val.(length.(permblocks))...) -end - -function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation - ) - return fuseaxes(axes, blocks(blockedperm)...) +function matricize_axes(style::FusionStyle, a::AbstractArray, length_codomain::Val) + unval(length_codomain) ≤ ndims(a) || + throw(ArgumentError("Codomain length exceeds number of dimensions.")) + biperm = trivialbiperm(length_codomain, Val(ndims(a))) + axesblocks = blocks(axes(a)[biperm]) + init_axis = trivial_axis(axis_type(a)) + return map(axesblocks) do axesblock + return reduce(axesblock; init = init_axis) do ax1, ax2 + return tensor_product(style, ax1, ax2) + end + end +end +function matricize_axes(a::AbstractArray, length_codomain::Val) + return matricize_axes(FusionStyle(a), a, length_codomain) end # Inner version takes a list of sub-permutations, overload this one if needed. @@ -81,33 +82,36 @@ end # matrix factorizations assume copy # maybe: copy=false kwarg -function matricize(a::AbstractArray, length1::Val, length2::Val) - return matricize(FusionStyle(a), a, length1, length2) +function matricize(a::AbstractArray, length_codomain::Val) + return matricize(FusionStyle(a), a, length_codomain) end # This is the primary function that should be overloaded for new fusion styles. # This assumes the permutation was already performed. -function matricize(style::FusionStyle, a::AbstractArray, length1::Val, length2::Val) +function matricize( + style::FusionStyle, a::AbstractArray, length_codomain::Val + ) return throw( MethodError( - matricize, Tuple{typeof(style), typeof(a), typeof(length1), typeof(length2)} + matricize, Tuple{typeof(style), typeof(a), typeof(length_codomain)} ) ) end function matricize( - a::AbstractArray, permblock1::Tuple{Vararg{Int}}, permblock2::Tuple{Vararg{Int}} + a::AbstractArray, + permblock_codomain::Tuple{Vararg{Int}}, permblock_domain::Tuple{Vararg{Int}} ) - return matricize(FusionStyle(a), a, permblock1, permblock2) + return matricize(FusionStyle(a), a, permblock_codomain, permblock_domain) end # This is a more advanced version to overload where the permutation is actually performed. function matricize( style::FusionStyle, a::AbstractArray, - permblock1::NTuple{N1, Int}, permblock2::NTuple{N2, Int} - ) where {N1, N2} - ndims(a) == length(permblock1) + length(permblock2) || + permblock_codomain::Tuple{Vararg{Int}}, permblock_domain::Tuple{Vararg{Int}} + ) + ndims(a) == length(permblock_codomain) + length(permblock_domain) || throw(ArgumentError("Invalid bipermutation")) - a_perm = bipermutedims(a, permblock1, permblock2) - return matricize(style, a_perm, Val(length(permblock1)), Val(length(permblock2))) + a_perm = bipermutedims(a, permblock_codomain, permblock_domain) + return matricize(style, a_perm, Val(length(permblock_codomain))) end # Process inputs such as `EllipsisNotation.Ellipsis`. @@ -133,11 +137,14 @@ function to_permblocks( permblocks2 = tuplesetcomplement(ntuple(identity, ndims(a)), permblocks[1]) return (permblocks[1], permblocks2) end -function matricize(a::AbstractArray, permblock1, permblock2) - return matricize(FusionStyle(a), a, permblock1, permblock2) + +function matricize(a::AbstractArray, permblock_codomain, permblock_domain) + return matricize(FusionStyle(a), a, permblock_codomain, permblock_domain) end -function matricize(style::FusionStyle, a::AbstractArray, permblock1, permblock2) - return matricize(style, a, to_permblocks(a, (permblock1, permblock2))...) +function matricize( + style::FusionStyle, a::AbstractArray, permblock_codomain, permblock_domain + ) + return matricize(style, a, to_permblocks(a, (permblock_codomain, permblock_domain))...) end function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}) @@ -152,22 +159,22 @@ end # ==================================== unmatricize ======================================= function unmatricize( m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes) + return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain) end # This is the primary function that should be overloaded for new fusion styles. function unmatricize( style::FusionStyle, m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) return throw( MethodError( unmatricize, Tuple{ - typeof(style), typeof(m), typeof(codomain_axes), typeof(domain_axes), + typeof(style), typeof(m), typeof(axes_codomain), typeof(axes_domain), }, ) ) @@ -190,9 +197,9 @@ function unmatricize( end function unmatricize( style::FusionStyle, m::AbstractMatrix, axes_dest, - invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + invperm_codomain::Tuple{Vararg{Int}}, invperm_domain::Tuple{Vararg{Int}}, ) - invbiperm = permmortar((invperm1, invperm2)) + invbiperm = permmortar((invperm_codomain, invperm_domain)) length(axes_dest) == length(invbiperm) || throw(ArgumentError("axes do not match permutation")) blocked_axes = axes_dest[invbiperm] @@ -213,15 +220,15 @@ end function unmatricize!( a_dest::AbstractArray, m::AbstractMatrix, - invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + invperm_codomain::Tuple{Vararg{Int}}, invperm_domain::Tuple{Vararg{Int}}, ) - return unmatricize!(FusionStyle(m), a_dest, m, invperm1, invperm2) + return unmatricize!(FusionStyle(m), a_dest, m, invperm_codomain, invperm_domain) end function unmatricize!( style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, - invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + invperm_codomain::Tuple{Vararg{Int}}, invperm_domain::Tuple{Vararg{Int}}, ) - invbiperm = permmortar((invperm1, invperm2)) + invbiperm = permmortar((invperm_codomain, invperm_domain)) ndims(a_dest) == length(invbiperm) || throw(ArgumentError("destination does not match permutation")) blocked_axes = axes(a_dest)[invbiperm] @@ -244,17 +251,19 @@ end function unmatricizeadd!( a_dest::AbstractArray, m::AbstractMatrix, - invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + invperm_codomain::Tuple{Vararg{Int}}, invperm_domain::Tuple{Vararg{Int}}, α::Number, β::Number ) - return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invperm1, invperm2, α, β) + return unmatricizeadd!( + FusionStyle(a_dest), a_dest, m, invperm_codomain, invperm_domain, α, β + ) end function unmatricizeadd!( style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, - invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + invperm_codomain::Tuple{Vararg{Int}}, invperm_domain::Tuple{Vararg{Int}}, α::Number, β::Number, ) - a12 = unmatricize(style, m, axes(a_dest), invperm1, invperm2) + a12 = unmatricize(style, m, axes(a_dest), invperm_codomain, invperm_domain) a_dest .= α .* a12 .+ β .* a_dest return a_dest end @@ -279,13 +288,13 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() -function matricize(style::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val) - return reshape(a, fuseaxes(axes(a), length1, length2)) +function matricize(style::ReshapeFusion, a::AbstractArray, length_codomain::Val) + return reshape(a, matricize_axes(style, a, length_codomain)) end function unmatricize( style::ReshapeFusion, m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return reshape(m, (codomain_axes..., domain_axes...)) + return reshape(m, (axes_codomain..., axes_domain...)) end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 883c5be..9b2ec87 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -35,21 +35,21 @@ for f in MATRIX_FUNCTIONS @eval begin function $f( a::AbstractArray, - codomain_length::Val, domain_length::Val; + length_codomain::Val; kwargs..., ) - a_mat = matricize(a, codomain_length, domain_length) + a_mat = matricize(a, length_codomain) fa_mat = Base.$f(a_mat; kwargs...) - biperm = blockedtrivialperm((codomain_length, domain_length)) + biperm = trivialbiperm(length_codomain, Val(ndims(a))) return unmatricize(fa_mat, axes(a)[biperm]) end function $f( a::AbstractArray, - codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs..., ) - a_perm = bipermutedims(a, codomain_perm, domain_perm) - return $f(a_perm, Val(length(codomain_perm)), Val(length(domain_perm)); kwargs...) + a_perm = bipermutedims(a, perm_codomain, perm_domain) + return $f(a_perm, Val(length(perm_codomain)); kwargs...) end function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...) biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 9292f90..9bd174a 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -40,7 +40,7 @@ elts = (Float64, ComplexF64) Q, R = qr(A, (2, 1), (4, 3); full = true) @test A ≈ contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) - Q, R = qr(A, Val(2), Val(2); full = true) + Q, R = qr(A, Val(2); full = true) @test A ≈ contract((:a, :b, :c, :d), Q, (:a, :b, :q), R, (:q, :c, :d)) end diff --git a/test/test_matrixfunctions.jl b/test/test_matrixfunctions.jl index 3e68c4c..056201e 100644 --- a/test/test_matrixfunctions.jl +++ b/test/test_matrixfunctions.jl @@ -17,7 +17,7 @@ using Test: @test, @testset local fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2)) @test fa ≈ fa′ end - fa = TensorAlgebra.$f(a, Val(2), Val(2)) + fa = TensorAlgebra.$f(a, Val(2)) fa′ = reshape($f(reshape(a, (4, 4))), (2, 2, 2, 2)) @test fa ≈ fa′ end From 9e0b233aa7a3d8e81bf48376d85c2a2cce7e4d10 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 12:52:02 -0500 Subject: [PATCH 2/7] Rename length_codomain to ndims_codomain --- src/factorizations.jl | 64 +++++++++++++++++++++--------------------- src/matricize.jl | 22 +++++++-------- src/matrixfunctions.jl | 10 ++----- 3 files changed, 46 insertions(+), 50 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index c232a97..837ffb5 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -6,10 +6,10 @@ for f in ( :factorize, ) @eval begin - function $f(A::AbstractArray, length_codomain::Val; kwargs...) - A_mat = matricize(A, length_codomain) + function $f(A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(A, ndims_codomain) X, Y = MatrixAlgebra.$f(A_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(A))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) axes_Y = tuplemortar(((axes(Y, 1),), axes_domain)) @@ -44,7 +44,7 @@ end """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R qr(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> Q, R - qr(A::AbstractArray, length_codomain::Val; kwargs...) -> Q, R + qr(A::AbstractArray, ndims_codomain::Val; kwargs...) -> Q, R qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as @@ -64,7 +64,7 @@ qr """ lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q lq(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> L, Q - lq(A::AbstractArray, length_codomain::Val; kwargs...) -> L, Q + lq(A::AbstractArray, ndims_codomain::Val; kwargs...) -> L, Q lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as @@ -84,7 +84,7 @@ lq """ left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P left_polar(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> W, P - left_polar(A::AbstractArray, length_codomain::Val; kwargs...) -> W, P + left_polar(A::AbstractArray, ndims_codomain::Val; kwargs...) -> W, P left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as @@ -102,7 +102,7 @@ left_polar """ right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W right_polar(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> P, W - right_polar(A::AbstractArray, length_codomain::Val; kwargs...) -> P, W + right_polar(A::AbstractArray, ndims_codomain::Val; kwargs...) -> P, W right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as @@ -120,7 +120,7 @@ right_polar """ left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C left_orth(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> V, C - left_orth(A::AbstractArray, length_codomain::Val; kwargs...) -> V, C + left_orth(A::AbstractArray, ndims_codomain::Val; kwargs...) -> V, C left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -138,7 +138,7 @@ left_orth """ right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V right_orth(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> C, V - right_orth(A::AbstractArray, length_codomain::Val; kwargs...) -> C, V + right_orth(A::AbstractArray, ndims_codomain::Val; kwargs...) -> C, V right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as @@ -156,7 +156,7 @@ right_orth """ factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y factorize(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> X, Y - factorize(A::AbstractArray, length_codomain::Val; kwargs...) -> X, Y + factorize(A::AbstractArray, ndims_codomain::Val; kwargs...) -> X, Y factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as @@ -174,7 +174,7 @@ factorize """ eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V eigen(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> D, V - eigen(A::AbstractArray, length_codomain::Val; kwargs...) -> D, V + eigen(A::AbstractArray, ndims_codomain::Val; kwargs...) -> D, V eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as @@ -191,11 +191,11 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`, `MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`. """ -function eigen(A::AbstractArray, length_codomain::Val; kwargs...) +function eigen(A::AbstractArray, ndims_codomain::Val; kwargs...) # tensor to matrix - A_mat = matricize(A, length_codomain) + A_mat = matricize(A, ndims_codomain) D, V = MatrixAlgebra.eigen!(A_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(A))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, = blocks(axes(A)[biperm]) axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, unmatricize(V, axes_V) @@ -204,7 +204,7 @@ end """ eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D eigvals(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> D - eigvals(A::AbstractArray, length_codomain::Val; kwargs...) -> D + eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...) -> D eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as @@ -219,15 +219,15 @@ their labels or directly through a bi-permutation. The output is a vector of eig See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`. """ -function eigvals(A::AbstractArray, length_codomain::Val; kwargs...) - A_mat = matricize(A, length_codomain) +function eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(A, ndims_codomain) return MatrixAlgebra.eigvals!(A_mat; kwargs...) end """ svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> U, S, Vᴴ - svd(A::AbstractArray, length_codomain::Val; kwargs...) -> U, S, Vᴴ + svd(A::AbstractArray, ndims_codomain::Val; kwargs...) -> U, S, Vᴴ svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as @@ -243,10 +243,10 @@ their labels or directly through a bi-permutation. See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`. """ -function svd(A::AbstractArray, length_codomain::Val; kwargs...) - A_mat = matricize(A, length_codomain) +function svd(A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(A, ndims_codomain) U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(A))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain, axes_domain = blocks(axes(A)[biperm]) axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) @@ -256,7 +256,7 @@ end """ svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S svdvals(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}) -> S - svdvals(A::AbstractArray, length_codomain::Val) -> S + svdvals(A::AbstractArray, ndims_codomain::Val) -> S svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as @@ -265,15 +265,15 @@ their labels or directly through a bi-permutation. The output is a vector of sin See also `MatrixAlgebraKit.svd_vals!`. """ -function svdvals(A::AbstractArray, length_codomain::Val) - A_mat = matricize(A, length_codomain) +function svdvals(A::AbstractArray, ndims_codomain::Val) + A_mat = matricize(A, ndims_codomain) return MatrixAlgebra.svdvals!(A_mat) end """ left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N left_null(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> N - left_null(A::AbstractArray, length_codomain::Val; kwargs...) -> N + left_null(A::AbstractArray, ndims_codomain::Val; kwargs...) -> N left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as @@ -289,10 +289,10 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function left_null(A::AbstractArray, length_codomain::Val; kwargs...) - A_mat = matricize(A, length_codomain) +function left_null(A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(A, ndims_codomain) N = MatrixAlgebraKit.left_null!(A_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(A))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain = first(blocks(axes(A)[biperm])) axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) return unmatricize(N, axes_N) @@ -301,7 +301,7 @@ end """ right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ right_null(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> Nᴴ - right_null(A::AbstractArray, length_codomain::Val::Val; kwargs...) -> Nᴴ + right_null(A::AbstractArray, ndims_codomain::Val::Val; kwargs...) -> Nᴴ right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as @@ -317,10 +317,10 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`. The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise. """ -function right_null(A::AbstractArray, length_codomain::Val; kwargs...) - A_mat = matricize(A, length_codomain) +function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...) + A_mat = matricize(A, ndims_codomain) Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(A))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_domain = last(blocks((axes(A)[biperm]))) axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain)) return unmatricize(Nᴴ, axes_Nᴴ) diff --git a/src/matricize.jl b/src/matricize.jl index 52661d0..5297638 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -23,10 +23,10 @@ trivial_axis(::Type{<:AbstractBlockedUnitRange}) = blockedrange([1]) tensor_product(::FusionStyle, ax1, ax2) = ax1 ⊗ ax2 # Inner version takes a list of sub-permutations, overload this one if needed. -function matricize_axes(style::FusionStyle, a::AbstractArray, length_codomain::Val) - unval(length_codomain) ≤ ndims(a) || +function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val) + unval(ndims_codomain) ≤ ndims(a) || throw(ArgumentError("Codomain length exceeds number of dimensions.")) - biperm = trivialbiperm(length_codomain, Val(ndims(a))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) axesblocks = blocks(axes(a)[biperm]) init_axis = trivial_axis(axis_type(a)) return map(axesblocks) do axesblock @@ -35,8 +35,8 @@ function matricize_axes(style::FusionStyle, a::AbstractArray, length_codomain::V end end end -function matricize_axes(a::AbstractArray, length_codomain::Val) - return matricize_axes(FusionStyle(a), a, length_codomain) +function matricize_axes(a::AbstractArray, ndims_codomain::Val) + return matricize_axes(FusionStyle(a), a, ndims_codomain) end # Inner version takes a list of sub-permutations, overload this one if needed. @@ -82,17 +82,17 @@ end # matrix factorizations assume copy # maybe: copy=false kwarg -function matricize(a::AbstractArray, length_codomain::Val) - return matricize(FusionStyle(a), a, length_codomain) +function matricize(a::AbstractArray, ndims_codomain::Val) + return matricize(FusionStyle(a), a, ndims_codomain) end # This is the primary function that should be overloaded for new fusion styles. # This assumes the permutation was already performed. function matricize( - style::FusionStyle, a::AbstractArray, length_codomain::Val + style::FusionStyle, a::AbstractArray, ndims_codomain::Val ) return throw( MethodError( - matricize, Tuple{typeof(style), typeof(a), typeof(length_codomain)} + matricize, Tuple{typeof(style), typeof(a), typeof(ndims_codomain)} ) ) end @@ -288,8 +288,8 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() -function matricize(style::ReshapeFusion, a::AbstractArray, length_codomain::Val) - return reshape(a, matricize_axes(style, a, length_codomain)) +function matricize(style::ReshapeFusion, a::AbstractArray, ndims_codomain::Val) + return reshape(a, matricize_axes(style, a, ndims_codomain)) end function unmatricize( style::ReshapeFusion, m::AbstractMatrix, diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl index 9b2ec87..2a4294e 100644 --- a/src/matrixfunctions.jl +++ b/src/matrixfunctions.jl @@ -33,14 +33,10 @@ const MATRIX_FUNCTIONS = [ for f in MATRIX_FUNCTIONS @eval begin - function $f( - a::AbstractArray, - length_codomain::Val; - kwargs..., - ) - a_mat = matricize(a, length_codomain) + function $f(a::AbstractArray, ndims_codomain::Val; kwargs...) + a_mat = matricize(a, ndims_codomain) fa_mat = Base.$f(a_mat; kwargs...) - biperm = trivialbiperm(length_codomain, Val(ndims(a))) + biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) return unmatricize(fa_mat, axes(a)[biperm]) end function $f( From 30a76195293ce1897e0fca4152643d6641d24d8b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 16:12:17 -0500 Subject: [PATCH 3/7] Reorganize BlockArrays matricization code --- Project.toml | 2 -- src/TensorAlgebra.jl | 20 ++------------ src/blockarrays.jl | 25 +++++++++++++++++ src/matricize.jl | 66 ++++++++++++++------------------------------ test/test_basics.jl | 43 ++++++++++++----------------- 5 files changed, 67 insertions(+), 89 deletions(-) create mode 100644 src/blockarrays.jl diff --git a/Project.toml b/Project.toml index 311e276..7bfcb2c 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" @@ -26,7 +25,6 @@ EllipsisNotation = "1.8" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6" TensorOperations = "5" -TensorProducts = "0.1.5" TupleTools = "1.6" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index b848c1d..ec12c5a 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,28 +1,14 @@ module TensorAlgebra -export contract, - contract!, - eigen, - eigvals, - factorize, - left_null, - left_orth, - left_polar, - lq, - qr, - right_null, - right_orth, - right_polar, - orth, - polar, - svd, - svdvals +export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, + lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals include("MatrixAlgebra.jl") include("blockedtuple.jl") include("blockedpermutation.jl") include("BaseExtensions/BaseExtensions.jl") include("matricize.jl") +include("blockarrays.jl") include("contract/contract.jl") include("contract/output_labels.jl") include("contract/blockedperms.jl") diff --git a/src/blockarrays.jl b/src/blockarrays.jl new file mode 100644 index 0000000..09cec3d --- /dev/null +++ b/src/blockarrays.jl @@ -0,0 +1,25 @@ +using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, blockedrange, + eachblockaxes1 + +struct BlockReshapeFusion <: FusionStyle end + +FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion() + +trivial_axis(::Type{<:AbstractBlockedUnitRange}) = blockedrange([1]) +function mortar_axis(axs) + all(isone ∘ first, axs) || + throw(ArgumentError("Only one-based axes are supported")) + return blockedrange(length.(axs)) +end +function tensor_product_axis( + ::BlockReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange + ) + isone(first(r1)) || isone(first(r2)) || + throw(ArgumentError("Only one-based axes are supported")) + blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) + blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs)) + return mortar_axis(blockaxs) +end +function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::Val) + return reshape(a, matricize_axes(style, a, ndims_codomain)) +end diff --git a/src/matricize.jl b/src/matricize.jl index 5297638..2e04dcb 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -1,9 +1,5 @@ -using LinearAlgebra: Diagonal - -using BlockArrays: AbstractBlockedUnitRange, blockedrange - -using TensorProducts: ⊗ using .BaseExtensions: _permutedims, _permutedims! +using LinearAlgebra: Diagonal # ===================================== FusionStyle ====================================== abstract type FusionStyle end @@ -15,14 +11,18 @@ FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) axis_type(a::AbstractArray) = eltype(axes(a)) axis_type(a::AbstractArray{<:Any, 0}) = Base.OneTo{Int} trivial_axis(::Type{<:AbstractUnitRange}) = Base.OneTo(1) -trivial_axis(::Type{<:AbstractBlockedUnitRange}) = blockedrange([1]) -# Fallback to `TensorProducts.⊗`. -# TODO: Remove dependency on TensorProducts.jl, update downstream packages: -# BlockSparseArrays.jl, GradedArrays.jl, FusionTensors.jl. -tensor_product(::FusionStyle, ax1, ax2) = ax1 ⊗ ax2 +# Tensor product two spaces (ranges) together based on a fusion style. +function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange) + return tensor_product_axis(ReshapeFusion(), r1, r2) +end +function tensor_product_axis(r1::AbstractUnitRange, r2::AbstractUnitRange) + style1 = FusionStyle(r1) + style2 = FusionStyle(r2) + style1 == style2 || error("Styles must match.") + return tensor_product_axis(style1, r1, r2) +end -# Inner version takes a list of sub-permutations, overload this one if needed. function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val) unval(ndims_codomain) ≤ ndims(a) || throw(ArgumentError("Codomain length exceeds number of dimensions.")) @@ -31,7 +31,7 @@ function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Va init_axis = trivial_axis(axis_type(a)) return map(axesblocks) do axesblock return reduce(axesblock; init = init_axis) do ax1, ax2 - return tensor_product(style, ax1, ax2) + return tensor_product_axis(style, ax1, ax2) end end end @@ -45,10 +45,10 @@ end # TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`. # Keeping it here for backwards compatibility. function bipermutedims(a::AbstractArray, perm1, perm2) - return permuteblockeddims(a, perm1, perm2) + return _permutedims(a, (perm1..., perm2...)) end function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) - return permuteblockeddims!(a_dest, a_src, perm1, perm2) + return _permutedims!(a_dest, a_src, (perm1..., perm2...)) end function bipermutedims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) return bipermutedims(a, blocks(biperm)...) @@ -59,24 +59,6 @@ function bipermutedims!( return bipermutedims!(a_dest, a_src, blocks(biperm)...) end -# Older interface. -# TODO: Deprecate in favor of `bipermutedims` (or decide if we want to keep it -# in case there are applications of more general partitionings). -function permuteblockeddims(a::AbstractArray, perm1, perm2) - return _permutedims(a, (perm1..., perm2...)) -end -function permuteblockeddims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) - return _permutedims!(a_dest, a_src, (perm1..., perm2...)) -end -function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) - return permuteblockeddims(a, blocks(biperm)...) -end -function permuteblockeddims!( - a_dest::AbstractArray, a_src::AbstractArray, biperm::AbstractBlockPermutation{2} - ) - return permuteblockeddims!(a_dest, a_src, blocks(biperm)...) -end - # ===================================== matricize ======================================== # TBD settle copy/not copy convention # matrix factorizations assume copy @@ -90,11 +72,7 @@ end function matricize( style::FusionStyle, a::AbstractArray, ndims_codomain::Val ) - return throw( - MethodError( - matricize, Tuple{typeof(style), typeof(a), typeof(ndims_codomain)} - ) - ) + return matricize(ReshapeFusion(), a, ndims_codomain) end function matricize( @@ -170,14 +148,7 @@ function unmatricize( axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return throw( - MethodError( - unmatricize, - Tuple{ - typeof(style), typeof(m), typeof(axes_codomain), typeof(axes_domain), - }, - ) - ) + return unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain) end function unmatricize(m::AbstractMatrix, blocked_axes::AbstractBlockTuple{2}) @@ -288,6 +259,11 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() +function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange) + isone(first(r1)) || isone(first(r2)) || + throw(ArgumentError("Only one-based axes are supported")) + return Base.OneTo(length(r1) * length(r2)) +end function matricize(style::ReshapeFusion, a::AbstractArray, ndims_codomain::Val) return reshape(a, matricize_axes(style, a, ndims_codomain)) end diff --git a/test/test_basics.jl b/test/test_basics.jl index b624bbf..2693c47 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -16,8 +16,6 @@ using TensorAlgebra: matricize, bipermutedims, bipermutedims!, - permuteblockeddims, - permuteblockeddims!, tuplemortar, unmatricize, unmatricize! @@ -35,29 +33,24 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test length_domain(bt) == 1 end - @testset "bipermutedims/permuteblockeddims (eltype=$elt)" for f in - (:bipermutedims, :permuteblockeddims), - elt in elts - f! = Symbol(f, :!) - @eval begin - a = randn($elt, 2, 3, 4, 5) - a_perm = $f(a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - - a = randn($elt, 2, 3, 4, 5) - a_perm = $f(a, (3, 1), (2, 4)) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - - a = randn($elt, 2, 3, 4, 5) - a_perm = Array{$elt}(undef, (4, 2, 3, 5)) - $f!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - - a = randn($elt, 2, 3, 4, 5) - a_perm = Array{$elt}(undef, (4, 2, 3, 5)) - $f!(a_perm, a, (3, 1), (2, 4)) - @test a_perm == permutedims(a, (3, 1, 2, 4)) - end + @testset "bipermutedims (eltype=$elt)" for elt in elts + a = randn(elt, 2, 3, 4, 5) + a_perm = bipermutedims(a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = bipermutedims(a, (3, 1), (2, 4)) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = Array{elt}(undef, (4, 2, 3, 5)) + bipermutedims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = Array{elt}(undef, (4, 2, 3, 5)) + bipermutedims!(a_perm, a, (3, 1), (2, 4)) + @test a_perm == permutedims(a, (3, 1, 2, 4)) end @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) From 8aa8ec45d56c4ba5fb2e025859919b2280cab5d4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 17:11:13 -0500 Subject: [PATCH 4/7] Better definitions of matricize for BlockArrays --- src/blockarrays.jl | 54 ++++++++++++++++++++++++++++--- src/matricize.jl | 7 ++-- test/test_blockarrays_contract.jl | 51 +++++++++++++---------------- 3 files changed, 74 insertions(+), 38 deletions(-) diff --git a/src/blockarrays.jl b/src/blockarrays.jl index 09cec3d..382e89b 100644 --- a/src/blockarrays.jl +++ b/src/blockarrays.jl @@ -1,11 +1,10 @@ -using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, blockedrange, - eachblockaxes1 +using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, blockedrange, + eachblockaxes1, mortar struct BlockReshapeFusion <: FusionStyle end - FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion() -trivial_axis(::Type{<:AbstractBlockedUnitRange}) = blockedrange([1]) +trivial_axis(::BlockReshapeFusion) = blockedrange([1]) function mortar_axis(axs) all(isone ∘ first, axs) || throw(ArgumentError("Only one-based axes are supported")) @@ -21,5 +20,50 @@ function tensor_product_axis( return mortar_axis(blockaxs) end function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::Val) - return reshape(a, matricize_axes(style, a, ndims_codomain)) + ax = matricize_axes(style, a, ndims_codomain) + reshaped_blocks_a = reshape(blocks(a), blocklength.(ax)) + bs = map(reshaped_blocks_a) do b + matricize(b, ndims_codomain) + end + return mortar(bs, ax) +end +using BlockArrays: blocklengths +function unmatricize( + ::BlockReshapeFusion, + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + ax = (codomain_axes..., domain_axes...) + reshaped_blocks_m = reshape(blocks(m), blocklength.(ax)) + bs = map(CartesianIndices(reshaped_blocks_m)) do I + block_axes_I = BlockedTuple( + map(ntuple(identity, length(ax))) do i + return Base.axes1(ax[i][Block(I[i])]) + end, + (length(codomain_axes), length(domain_axes)), + ) + return unmatricize(reshaped_blocks_m[I], block_axes_I) + end + return mortar(bs, ax) +end + +struct BlockedReshapeFusion <: FusionStyle end +FusionStyle(::Type{<:BlockedArray}) = BlockedReshapeFusion() +unblock(a::BlockedArray) = a.blocks +unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...] +unblock(a::AbstractArray) = a +function matricize(::BlockedReshapeFusion, a::AbstractArray, ndims_codomain::Val) + return matricize(ReshapeFusion(), unblock(a), ndims_codomain) +end +function unmatricize( + style::BlockedReshapeFusion, m::AbstractMatrix, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + a = unmatricize( + ReshapeFusion(), m, + Base.OneTo.(length.(axes_codomain)), Base.OneTo.(length.(axes_domain)), + ) + return BlockedArray(a, (axes_codomain..., axes_domain...)) end diff --git a/src/matricize.jl b/src/matricize.jl index 2e04dcb..1b94a1d 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -8,9 +8,7 @@ FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) # ======================================= misc ======================================== -axis_type(a::AbstractArray) = eltype(axes(a)) -axis_type(a::AbstractArray{<:Any, 0}) = Base.OneTo{Int} -trivial_axis(::Type{<:AbstractUnitRange}) = Base.OneTo(1) +trivial_axis(::FusionStyle) = trivial_axis(ReshapeFusion()) # Tensor product two spaces (ranges) together based on a fusion style. function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange) @@ -28,7 +26,7 @@ function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Va throw(ArgumentError("Codomain length exceeds number of dimensions.")) biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) axesblocks = blocks(axes(a)[biperm]) - init_axis = trivial_axis(axis_type(a)) + init_axis = trivial_axis(style) return map(axesblocks) do axesblock return reduce(axesblock; init = init_axis) do ax1, ax2 return tensor_product_axis(style, ax1, ax2) @@ -259,6 +257,7 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() +trivial_axis(::ReshapeFusion) = Base.OneTo(1) function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange) isone(first(r1)) || isone(first(r2)) || throw(ArgumentError("Only one-based axes are supported")) diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index 8dada0f..6f74b84 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -1,6 +1,6 @@ using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize using Random: randn! -using TensorAlgebra: contract +using TensorAlgebra: contract, matricize, unmatricize using Test: @test, @testset function randn_blockdiagonal(elt::Type, axes::Tuple) @@ -68,48 +68,41 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end @testset "BlockArray" begin - a1, a3, a3 = BlockArray.((a1, a2, a3)) + a1, a2, a3 = BlockArray.((a1, a2, a3)) # matrix matrix a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + m1 = matricize(a1, (2, 4), (1, 3)) + m2 = matricize(a2, (3, 1), (2, 4)) + m_dest = matricize(a_dest, Val(2)) + @test m_dest ≈ m1 * m2 # matrix vector a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + m1 = matricize(a1, (2, 3), (1, 4)) + m2 = matricize(a3, (2, 1), ()) + m_dest = matricize(a_dest, Val(2)) + @test m_dest ≈ m1 * m2 # vector matrix a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + m1 = matricize(a3, (), (1, 2)) + m2 = matricize(a1, (4, 1), (2, 3)) + m_dest = matricize(a_dest, Val(0)) + @test m_dest ≈ m1 * m2 # vector vector - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt, 0} - @test a_dest ≈ a_dest_dense + m1 = matricize(a3, (), (1, 2)) + m2 = matricize(a3, (2, 1), ()) + m_dest = matricize(a_dest, Val(0)) + @test m_dest ≈ m1 * m2 # outer product a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray{elt} - @test a_dest ≈ a_dest_dense + m1 = matricize(a3, (1, 2), ()) + m2 = matricize(a3, (), (1, 2)) + m_dest = matricize(a_dest, Val(2)) + @test m_dest ≈ m1 * m2 end end From be259355aefe85d132d7a49581716d1f1b82e6a8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 17:17:56 -0500 Subject: [PATCH 5/7] More general trivial_axis interface --- src/blockarrays.jl | 2 +- src/matricize.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/blockarrays.jl b/src/blockarrays.jl index 382e89b..0b8dbc2 100644 --- a/src/blockarrays.jl +++ b/src/blockarrays.jl @@ -4,7 +4,7 @@ using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, b struct BlockReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion() -trivial_axis(::BlockReshapeFusion) = blockedrange([1]) +trivial_axis(::BlockReshapeFusion, a::AbstractArray) = blockedrange([1]) function mortar_axis(axs) all(isone ∘ first, axs) || throw(ArgumentError("Only one-based axes are supported")) diff --git a/src/matricize.jl b/src/matricize.jl index 1b94a1d..e532300 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -8,7 +8,7 @@ FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) # ======================================= misc ======================================== -trivial_axis(::FusionStyle) = trivial_axis(ReshapeFusion()) +trivial_axis(style::FusionStyle, a::AbstractArray) = trivial_axis(ReshapeFusion(), a) # Tensor product two spaces (ranges) together based on a fusion style. function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange) @@ -26,7 +26,7 @@ function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Va throw(ArgumentError("Codomain length exceeds number of dimensions.")) biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) axesblocks = blocks(axes(a)[biperm]) - init_axis = trivial_axis(style) + init_axis = trivial_axis(style, a) return map(axesblocks) do axesblock return reduce(axesblock; init = init_axis) do ax1, ax2 return tensor_product_axis(style, ax1, ax2) @@ -257,7 +257,7 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() -trivial_axis(::ReshapeFusion) = Base.OneTo(1) +trivial_axis(::ReshapeFusion, a::AbstractArray) = Base.OneTo(1) function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange) isone(first(r1)) || isone(first(r2)) || throw(ArgumentError("Only one-based axes are supported")) From 9095637d587891d13b07d5709aa66c124b5e6a1a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 17:25:21 -0500 Subject: [PATCH 6/7] Mark as breaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7bfcb2c..feb52e0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.5.5" +version = "0.6.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From a0f686f992f4ad5d15ecc4f9de26cd4e5a49cb2e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 2 Dec 2025 17:27:06 -0500 Subject: [PATCH 7/7] Bump subdir versions --- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index a3bb587..b9f8ff9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,4 +9,4 @@ TensorAlgebra = {path = ".."} [compat] Documenter = "1.8.1" Literate = "2.20.1" -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" diff --git a/examples/Project.toml b/examples/Project.toml index 8b00c47..fbe2c20 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorAlgebra = {path = ".."} [compat] -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" diff --git a/test/Project.toml b/test/Project.toml index 9bdf530..fd7d5b2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -26,7 +26,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1"