diff --git a/Project.toml b/Project.toml index 2207a08..9ac4fbd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.6.1" +version = "0.6.2" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/matricize.jl b/src/matricize.jl index 200af43..10b15d9 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -8,30 +8,104 @@ FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) # ======================================= misc ======================================== -trivial_axis(style::FusionStyle, a::AbstractArray) = trivial_axis(ReshapeFusion(), a) +function trivial_axis( + style::FusionStyle, + ::Val{:codomain}, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_axis(style, a, axes_codomain, axes_domain) +end +function trivial_axis( + style::FusionStyle, + ::Val{:domain}, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_axis(style, a, axes_codomain, axes_domain) +end +function trivial_axis( + style::FusionStyle, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_axis(style, a) +end +function trivial_axis(style::FusionStyle, a::AbstractArray) + return trivial_axis(ReshapeFusion(), a) +end # Tensor product two spaces (ranges) together based on a fusion style. +function tensor_product_axis( + style::FusionStyle, ::Val{:codomain}, r1::AbstractUnitRange, r2::AbstractUnitRange + ) + return tensor_product_axis(style, r1, r2) +end +function tensor_product_axis( + style::FusionStyle, ::Val{:domain}, r1::AbstractUnitRange, r2::AbstractUnitRange + ) + return tensor_product_axis(style, r1, r2) +end function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange) return tensor_product_axis(ReshapeFusion(), r1, r2) end +function tensor_product_axis(side::Val, r1::AbstractUnitRange, r2::AbstractUnitRange) + style = tensor_product_fusionstyle(r1, r2) + return tensor_product_axis(style, side, r1, r2) +end function tensor_product_axis(r1::AbstractUnitRange, r2::AbstractUnitRange) + style = tensor_product_fusionstyle(r1, r2) + return tensor_product_axis(style, r1, r2) +end +function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange) style1 = FusionStyle(r1) style2 = FusionStyle(r2) style1 == style2 || error("Styles must match.") - return tensor_product_axis(style1, r1, r2) + return style1 end +function fused_axis( + style::FusionStyle, + side::Val{:codomain}, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + init_axis = trivial_axis(style, side, a, axes_codomain, axes_domain) + return reduce(axes_codomain; init = init_axis) do ax1, ax2 + return tensor_product_axis(style, side, ax1, ax2) + end +end +function fused_axis( + style::FusionStyle, + side::Val{:domain}, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + init_axis = trivial_axis(style, side, a, axes_codomain, axes_domain) + return reduce(axes_domain; init = init_axis) do ax1, ax2 + return tensor_product_axis(style, side, ax1, ax2) + end +end +function matricize_axes( + style::FusionStyle, + a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + axis_codomain = fused_axis(style, Val(:codomain), a, axes_codomain, axes_domain) + axis_domain = fused_axis(style, Val(:domain), a, axes_codomain, axes_domain) + return axis_codomain, axis_domain +end function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val) unval(ndims_codomain) ≤ ndims(a) || throw(ArgumentError("Codomain length exceeds number of dimensions.")) biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) - axesblocks = blocks(axes(a)[biperm]) - init_axis = trivial_axis(style, a) - return map(axesblocks) do axesblock - return reduce(axesblock; init = init_axis) do ax1, ax2 - return tensor_product_axis(style, ax1, ax2) - end - end + return matricize_axes(style, a, blocks(axes(a)[biperm])...) end function matricize_axes(a::AbstractArray, ndims_codomain::Val) return matricize_axes(FusionStyle(a), a, ndims_codomain)