Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.6.2"
version = "0.6.3"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
52 changes: 39 additions & 13 deletions src/blockarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@ using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, b
struct BlockReshapeFusion <: FusionStyle end
FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion()

trivial_axis(::BlockReshapeFusion, a::AbstractArray) = blockedrange([1])
function trivial_axis(
style::BlockReshapeFusion, side::Val{:codomain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return blockedrange([1])
end
function mortar_axis(axs)
all(isone ∘ first, axs) ||
throw(ArgumentError("Only one-based axes are supported"))
return blockedrange(length.(axs))
end
function tensor_product_axis(
::BlockReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange
style::BlockReshapeFusion, side::Val{:codomain},
r1::AbstractUnitRange, r2::AbstractUnitRange,
)
(isone(first(r1)) && isone(first(r2))) ||
throw(ArgumentError("Only one-based axes are supported"))
Expand All @@ -29,35 +36,33 @@ function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::
end
using BlockArrays: blocklengths
function unmatricize(
::BlockReshapeFusion,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
::BlockReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
ax = (codomain_axes..., domain_axes...)
ax = (axes_codomain..., axes_domain...)
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
bs = map(CartesianIndices(reshaped_blocks_m)) do I
block_axes_I = BlockedTuple(
map(ntuple(identity, length(ax))) do i
return Base.axes1(ax[i][Block(I[i])])
end,
(length(codomain_axes), length(domain_axes)),
(length(axes_codomain), length(axes_domain)),
)
return unmatricize(reshaped_blocks_m[I], block_axes_I)
end
return mortar(bs, ax)
end

struct BlockedReshapeFusion <: FusionStyle end
FusionStyle(::Type{<:BlockedArray}) = BlockedReshapeFusion()
FusionStyle(::Type{<:BlockedArray}) = ReshapeFusion()
unblock(a::BlockedArray) = a.blocks
unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...]
unblock(a::AbstractArray) = a
function matricize(::BlockedReshapeFusion, a::AbstractArray, ndims_codomain::Val)
function matricize(::ReshapeFusion, a::BlockedArray, ndims_codomain::Val)
return matricize(ReshapeFusion(), unblock(a), ndims_codomain)
end
function unmatricize(
style::BlockedReshapeFusion, m::AbstractMatrix,
function unmatricize_blocked(
style::ReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
Expand All @@ -67,3 +72,24 @@ function unmatricize(
)
return BlockedArray(a, (axes_codomain..., axes_domain...))
end
function unmatricize(
style::ReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
)
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
end
function unmatricize(
style::ReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
axes_domain::Tuple{Vararg{AbstractBlockedUnitRange}},
)
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
end
function unmatricize(
style::ReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractBlockedUnitRange}},
axes_domain::Tuple{AbstractBlockedUnitRange, Vararg{AbstractBlockedUnitRange}},
)
return unmatricize_blocked(style, m, axes_codomain, axes_domain)
end
94 changes: 58 additions & 36 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,48 +9,59 @@ FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,)))

# ======================================= misc ========================================
function trivial_axis(
style::FusionStyle,
::Val{:codomain},
a::AbstractArray,
style::FusionStyle, side::Val{:codomain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a, axes_codomain, axes_domain)
return throw(MethodError(trivial_axis, (style, side, a, axes_codomain, axes_domain)))
end
function trivial_axis(
style::FusionStyle,
::Val{:domain},
a::AbstractArray,
style::FusionStyle, ::Val{:domain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a, axes_codomain, axes_domain)
return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
end
function trivial_axis(
style::FusionStyle,
a::AbstractArray,
style::FusionStyle, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a)
return trivial_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
end
function trivial_axis(style::FusionStyle, a::AbstractArray)
return trivial_axis(ReshapeFusion(), a)
return trivial_axis(style, a, (), ())
end
function trivial_axis(
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(FusionStyle(a), a, axes_codomain, axes_domain)
end
function trivial_axis(side::Val, a::AbstractArray)
return trivial_axis(FusionStyle(a), side, a)
end
function trivial_axis(a::AbstractArray)
return trivial_axis(FusionStyle(a), a)
end

# Tensor product two spaces (ranges) together based on a fusion style.
function tensor_product_axis(
style::FusionStyle, ::Val{:codomain}, r1::AbstractUnitRange, r2::AbstractUnitRange
style::FusionStyle, side::Val{:codomain},
r1::AbstractUnitRange, r2::AbstractUnitRange,
)
return tensor_product_axis(style, r1, r2)
return throw(MethodError(tensor_product_axis, (style, side, r1, r2)))
end
function tensor_product_axis(
style::FusionStyle, ::Val{:domain}, r1::AbstractUnitRange, r2::AbstractUnitRange
)
return tensor_product_axis(style, r1, r2)
return tensor_product_axis(style, Val(:codomain), r1, r2)
end
function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange)
return tensor_product_axis(ReshapeFusion(), r1, r2)
function tensor_product_axis(
style::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange
)
return tensor_product_axis(style, Val(:codomain), r1, r2)
end
function tensor_product_axis(side::Val, r1::AbstractUnitRange, r2::AbstractUnitRange)
style = tensor_product_fusionstyle(r1, r2)
Expand All @@ -68,9 +79,7 @@ function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange
end

function fused_axis(
style::FusionStyle,
side::Val{:codomain},
a::AbstractArray,
style::FusionStyle, side::Val{:codomain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
Expand All @@ -80,9 +89,7 @@ function fused_axis(
end
end
function fused_axis(
style::FusionStyle,
side::Val{:domain},
a::AbstractArray,
style::FusionStyle, side::Val{:domain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
Expand All @@ -92,15 +99,21 @@ function fused_axis(
end
end
function matricize_axes(
style::FusionStyle,
a::AbstractArray,
style::FusionStyle, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
axis_codomain = fused_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
axis_domain = fused_axis(style, Val(:domain), a, axes_codomain, axes_domain)
return axis_codomain, axis_domain
end
function matricize_axes(
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return matricize_axes(FusionStyle(a), a, axes_codomain, axes_domain)
end
function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val)
unval(ndims_codomain) ≤ ndims(a) ||
throw(ArgumentError("Codomain length exceeds number of dimensions."))
Expand Down Expand Up @@ -136,15 +149,15 @@ end
# matrix factorizations assume copy
# maybe: copy=false kwarg

function matricize(a::AbstractArray, ndims_codomain::Val)
return matricize(FusionStyle(a), a, ndims_codomain)
end
# This is the primary function that should be overloaded for new fusion styles.
# This assumes the permutation was already performed.
function matricize(
style::FusionStyle, a::AbstractArray, ndims_codomain::Val
)
return matricize(ReshapeFusion(), a, ndims_codomain)
return throw(MethodError(matricize, (style, a, ndims_codomain)))
end
function matricize(a::AbstractArray, ndims_codomain::Val)
return matricize(FusionStyle(a), a, ndims_codomain)
end

function matricize(
Expand Down Expand Up @@ -207,20 +220,20 @@ function matricize(
end

# ==================================== unmatricize =======================================
# This is the primary function that should be overloaded for new fusion styles.
function unmatricize(
m::AbstractMatrix,
style::FusionStyle, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain)
return throw(MethodError(unmatricize, (style, m, axes_codomain, axes_domain)))
end
# This is the primary function that should be overloaded for new fusion styles.
function unmatricize(
style::FusionStyle, m::AbstractMatrix,
m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain)
return unmatricize(FusionStyle(m), m, axes_codomain, axes_domain)
end

function unmatricize(m::AbstractMatrix, blocked_axes::AbstractBlockTuple{2})
Expand Down Expand Up @@ -331,8 +344,17 @@ end
# Defaults to ReshapeFusion, a simple reshape
struct ReshapeFusion <: FusionStyle end
FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion()
trivial_axis(::ReshapeFusion, a::AbstractArray) = Base.OneTo(1)
function tensor_product_axis(::ReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange)
function trivial_axis(
style::ReshapeFusion, side::Val{:codomain}, a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return Base.OneTo(1)
end
function tensor_product_axis(
style::ReshapeFusion, side::Val{:codomain},
r1::AbstractUnitRange, r2::AbstractUnitRange,
)
(isone(first(r1)) && isone(first(r2))) ||
throw(ArgumentError("Only one-based axes are supported"))
return Base.OneTo(length(r1) * length(r2))
Expand Down
Loading