From 134bd49906c80b399c015de7a187af74e1cc301b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 5 Dec 2025 15:21:05 -0500 Subject: [PATCH] Rework the matricize interface --- Project.toml | 2 +- src/blockarrays.jl | 52 ++++++++++++++++++------- src/matricize.jl | 94 ++++++++++++++++++++++++++++------------------ 3 files changed, 98 insertions(+), 50 deletions(-) diff --git a/Project.toml b/Project.toml index 9ac4fbd..22b73c6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.6.2" +version = "0.6.3" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockarrays.jl b/src/blockarrays.jl index 056fcac..cbb2b62 100644 --- a/src/blockarrays.jl +++ b/src/blockarrays.jl @@ -4,14 +4,21 @@ using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, b struct BlockReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion() -trivial_axis(::BlockReshapeFusion, a::AbstractArray) = blockedrange([1]) +function trivial_axis( + style::BlockReshapeFusion, side::Val{:codomain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return blockedrange([1]) +end function mortar_axis(axs) all(isone ∘ first, axs) || throw(ArgumentError("Only one-based axes are supported")) return blockedrange(length.(axs)) end function tensor_product_axis( - ::BlockReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange + style::BlockReshapeFusion, side::Val{:codomain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, ) (isone(first(r1)) && isone(first(r2))) || throw(ArgumentError("Only one-based axes are supported")) @@ -29,35 +36,33 @@ function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain:: end using BlockArrays: blocklengths function unmatricize( - ::BlockReshapeFusion, - m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ::BlockReshapeFusion, m::AbstractMatrix, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - ax = (codomain_axes..., domain_axes...) + ax = (axes_codomain..., axes_domain...) reshaped_blocks_m = reshape(blocks(m), blocklength.(ax)) bs = map(CartesianIndices(reshaped_blocks_m)) do I block_axes_I = BlockedTuple( map(ntuple(identity, length(ax))) do i return Base.axes1(ax[i][Block(I[i])]) end, - (length(codomain_axes), length(domain_axes)), + (length(axes_codomain), length(axes_domain)), ) return unmatricize(reshaped_blocks_m[I], block_axes_I) end return mortar(bs, ax) end -struct BlockedReshapeFusion <: FusionStyle end -FusionStyle(::Type{<:BlockedArray}) = BlockedReshapeFusion() +FusionStyle(::Type{<:BlockedArray}) = ReshapeFusion() unblock(a::BlockedArray) = a.blocks unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...] unblock(a::AbstractArray) = a -function matricize(::BlockedReshapeFusion, a::AbstractArray, ndims_codomain::Val) +function matricize(::ReshapeFusion, a::BlockedArray, ndims_codomain::Val) return matricize(ReshapeFusion(), unblock(a), ndims_codomain) end -function unmatricize( - style::BlockedReshapeFusion, m::AbstractMatrix, +function unmatricize_blocked( + style::ReshapeFusion, m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -67,3 +72,24 @@ function unmatricize( ) return BlockedArray(a, (axes_codomain..., axes_domain...)) end +function unmatricize( + style::ReshapeFusion, m::AbstractMatrix, + axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}}, + axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}}, + ) + return unmatricize_blocked(style, m, axes_codomain, axes_domain) +end +function unmatricize( + style::ReshapeFusion, m::AbstractMatrix, + axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}}, + axes_domain::Tuple{Vararg{AbstractBlockedUnitRange}}, + ) + return unmatricize_blocked(style, m, axes_codomain, axes_domain) +end +function unmatricize( + style::ReshapeFusion, m::AbstractMatrix, + axes_codomain::Tuple{Vararg{AbstractBlockedUnitRange}}, + axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}}, + ) + return unmatricize_blocked(style, m, axes_codomain, axes_domain) +end diff --git a/src/matricize.jl b/src/matricize.jl index 10b15d9..01e9f66 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -9,48 +9,59 @@ FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) # ======================================= misc ======================================== function trivial_axis( - style::FusionStyle, - ::Val{:codomain}, - a::AbstractArray, + style::FusionStyle, side::Val{:codomain}, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return trivial_axis(style, a, axes_codomain, axes_domain) + return throw(MethodError(trivial_axis, (style, side, a, axes_codomain, axes_domain))) end function trivial_axis( - style::FusionStyle, - ::Val{:domain}, - a::AbstractArray, + style::FusionStyle, ::Val{:domain}, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return trivial_axis(style, a, axes_codomain, axes_domain) + return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain) end function trivial_axis( - style::FusionStyle, - a::AbstractArray, + style::FusionStyle, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return trivial_axis(style, a) + return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain) end function trivial_axis(style::FusionStyle, a::AbstractArray) - return trivial_axis(ReshapeFusion(), a) + return trivial_axis(style, a, (), ()) +end +function trivial_axis( + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_axis(FusionStyle(a), a, axes_codomain, axes_domain) +end +function trivial_axis(side::Val, a::AbstractArray) + return trivial_axis(FusionStyle(a), side, a) +end +function trivial_axis(a::AbstractArray) + return trivial_axis(FusionStyle(a), a) end # Tensor product two spaces (ranges) together based on a fusion style. function tensor_product_axis( - style::FusionStyle, ::Val{:codomain}, r1::AbstractUnitRange, r2::AbstractUnitRange + style::FusionStyle, side::Val{:codomain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, ) - return tensor_product_axis(style, r1, r2) + return throw(MethodError(tensor_product_axis, (style, side, r1, r2))) end function tensor_product_axis( style::FusionStyle, ::Val{:domain}, r1::AbstractUnitRange, r2::AbstractUnitRange ) - return tensor_product_axis(style, r1, r2) + return tensor_product_axis(style, Val(:codomain), r1, r2) end -function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange) - return tensor_product_axis(ReshapeFusion(), r1, r2) +function tensor_product_axis( + style::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange + ) + return tensor_product_axis(style, Val(:codomain), r1, r2) end function tensor_product_axis(side::Val, r1::AbstractUnitRange, r2::AbstractUnitRange) style = tensor_product_fusionstyle(r1, r2) @@ -68,9 +79,7 @@ function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange end function fused_axis( - style::FusionStyle, - side::Val{:codomain}, - a::AbstractArray, + style::FusionStyle, side::Val{:codomain}, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -80,9 +89,7 @@ function fused_axis( end end function fused_axis( - style::FusionStyle, - side::Val{:domain}, - a::AbstractArray, + style::FusionStyle, side::Val{:domain}, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -92,8 +99,7 @@ function fused_axis( end end function matricize_axes( - style::FusionStyle, - a::AbstractArray, + style::FusionStyle, a::AbstractArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -101,6 +107,13 @@ function matricize_axes( axis_domain = fused_axis(style, Val(:domain), a, axes_codomain, axes_domain) return axis_codomain, axis_domain end +function matricize_axes( + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return matricize_axes(FusionStyle(a), a, axes_codomain, axes_domain) +end function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val) unval(ndims_codomain) ≤ ndims(a) || throw(ArgumentError("Codomain length exceeds number of dimensions.")) @@ -136,15 +149,15 @@ end # matrix factorizations assume copy # maybe: copy=false kwarg -function matricize(a::AbstractArray, ndims_codomain::Val) - return matricize(FusionStyle(a), a, ndims_codomain) -end # This is the primary function that should be overloaded for new fusion styles. # This assumes the permutation was already performed. function matricize( style::FusionStyle, a::AbstractArray, ndims_codomain::Val ) - return matricize(ReshapeFusion(), a, ndims_codomain) + return throw(MethodError(matricize, (style, a, ndims_codomain))) +end +function matricize(a::AbstractArray, ndims_codomain::Val) + return matricize(FusionStyle(a), a, ndims_codomain) end function matricize( @@ -207,20 +220,20 @@ function matricize( end # ==================================== unmatricize ======================================= +# This is the primary function that should be overloaded for new fusion styles. function unmatricize( - m::AbstractMatrix, + style::FusionStyle, m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain) + return throw(MethodError(unmatricize, (style, m, axes_codomain, axes_domain))) end -# This is the primary function that should be overloaded for new fusion styles. function unmatricize( - style::FusionStyle, m::AbstractMatrix, + m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain) + return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain) end function unmatricize(m::AbstractMatrix, blocked_axes::AbstractBlockTuple{2}) @@ -331,8 +344,17 @@ end # Defaults to ReshapeFusion, a simple reshape struct ReshapeFusion <: FusionStyle end FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() -trivial_axis(::ReshapeFusion, a::AbstractArray) = Base.OneTo(1) -function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange) +function trivial_axis( + style::ReshapeFusion, side::Val{:codomain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return Base.OneTo(1) +end +function tensor_product_axis( + style::ReshapeFusion, side::Val{:codomain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, + ) (isone(first(r1)) && isone(first(r2))) || throw(ArgumentError("Only one-based axes are supported")) return Base.OneTo(length(r1) * length(r2))