Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.6.4"
version = "0.6.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
90 changes: 41 additions & 49 deletions src/MatrixAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,86 +1,82 @@
module MatrixAlgebra

export eigen,
eigen!,
eigen!!,
eigvals,
eigvals!,
eigvals!!,
factorize,
factorize!,
factorize!!,
lq,
lq!,
lq!!,
orth,
orth!,
orth!!,
polar,
polar!,
polar!!,
qr,
qr!,
qr!!,
svd,
svd!,
svd!!,
svdvals,
svdvals!
svdvals!!

using LinearAlgebra: LinearAlgebra, norm
using MatrixAlgebraKit
import MatrixAlgebraKit as MAK

for (f, f_full, f_compact) in (
(:qr, :qr_full, :qr_compact),
(:qr!, :qr_full!, :qr_compact!),
(:qr!!, :qr_full!, :qr_compact!),
(:lq, :lq_full, :lq_compact),
(:lq!, :lq_full!, :lq_compact!),
(:lq!!, :lq_full!, :lq_compact!),
)
@eval begin
function $f(A::AbstractMatrix; full::Bool = false, kwargs...)
f = full ? $f_full : $f_compact
return f(A; kwargs...)
return full ? MAK.$f_full(A; kwargs...) : MAK.$f_compact(A; kwargs...)
end
end
end

for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in (
(:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc),
(:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
(:eigen!!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
)
@eval begin
function $eigen(A::AbstractMatrix; trunc = nothing, ishermitian = nothing, kwargs...)
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
return if !isnothing(trunc)
if ishermitian
$eigh_trunc(A; trunc, kwargs...)
MAK.$eigh_trunc(A; trunc, kwargs...)
else
$eig_trunc(A; trunc, kwargs...)
MAK.$eig_trunc(A; trunc, kwargs...)
end
else
if ishermitian
$eigh_full(A; kwargs...)
MAK.$eigh_full(A; kwargs...)
else
$eig_full(A; kwargs...)
MAK.$eig_full(A; kwargs...)
end
end
end
end
end

for (eigvals, eigh_vals, eig_vals) in
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!))
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!!, :eigh_vals!, :eig_vals!))
@eval begin
function $eigvals(A::AbstractMatrix; ishermitian = nothing, kwargs...)
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
f = (ishermitian ? $eigh_vals : $eig_vals)
return f(A; kwargs...)
return ishermitian ? MAK.$eigh_vals(A; kwargs...) : MAK.$eig_vals(A; kwargs...)
end
end
end

for (svd, svd_trunc, svd_full, svd_compact) in (
(:svd, :svd_trunc, :svd_full, :svd_compact),
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
(:svd!!, :svd_trunc!, :svd_full!, :svd_compact!),
)
_svd = Symbol(:_, svd)
@eval begin
function $svd(
A::AbstractMatrix;
full::Union{Bool, Val} = Val(false),
trunc = nothing,
A::AbstractMatrix; full::Union{Bool, Val} = Val(false), trunc = nothing,
kwargs...,
)
return $_svd(full, trunc, A; kwargs...)
Expand All @@ -89,13 +85,13 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
return $_svd(Val(full), trunc, A; kwargs...)
end
function $_svd(full::Val{false}, trunc::Nothing, A::AbstractMatrix; kwargs...)
return $svd_compact(A; kwargs...)
return MAK.$svd_compact(A; kwargs...)
end
function $_svd(full::Val{false}, trunc, A::AbstractMatrix; kwargs...)
return $svd_trunc(A; trunc, kwargs...)
return MAK.$svd_trunc(A; trunc, kwargs...)
end
function $_svd(full::Val{true}, trunc::Nothing, A::AbstractMatrix; kwargs...)
return $svd_full(A; kwargs...)
return MAK.$svd_full(A; kwargs...)
end
function $_svd(full::Val{true}, trunc, A::AbstractMatrix; kwargs...)
return throw(
Expand All @@ -107,55 +103,52 @@ for (svd, svd_trunc, svd_full, svd_compact) in (
end
end

for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!))
for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!!, :svd_vals!))
@eval begin
function $svdvals(A::AbstractMatrix; ishermitian = nothing, kwargs...)
return $svd_vals(A; kwargs...)
return MAK.$svd_vals(A; kwargs...)
end
end
end

for (polar, left_polar, right_polar) in
((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!))
((:polar, :left_polar, :right_polar), (:polar!!, :left_polar!, :right_polar!))
@eval begin
function $polar(A::AbstractMatrix; side = :left, kwargs...)
f = if side == :left
$left_polar
return if side == :left
MAK.$left_polar(A; kwargs...)
elseif side == :right
$right_polar
MAK.$right_polar(A; kwargs...)
else
throw(ArgumentError("`side=$side` not supported."))
throw(ArgumentError("`side = $side` not supported."))
end
return f(A; kwargs...)
end
end
end

for (orth, left_orth, right_orth) in
((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!))
((:orth, :left_orth, :right_orth), (:orth!!, :left_orth!, :right_orth!))
@eval begin
function $orth(A::AbstractMatrix; side = :left, kwargs...)
f = if side == :left
$left_orth
return if side == :left
MAK.$left_orth(A; kwargs...)
elseif side == :right
$right_orth
MAK.$right_orth(A; kwargs...)
else
throw(ArgumentError("`side=$side` not supported."))
throw(ArgumentError("`side = $side` not supported."))
end
return f(A; kwargs...)
end
end
end

for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!))
for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!!, :orth!!))
@eval begin
function $factorize(A::AbstractMatrix; orth = :left, kwargs...)
f = if orth in (:left, :right)
$orth_f
return if orth in (:left, :right)
$orth_f(A; side = orth, kwargs...)
else
throw(ArgumentError("`orth=$orth` not supported."))
throw(ArgumentError("`orth = $orth` not supported."))
end
return f(A; side = orth, kwargs...)
end
end
end
Expand Down Expand Up @@ -190,7 +183,6 @@ function truncdegen(strategy::TruncationStrategy; atol::Real = 0, rtol::Real = 0
end

using MatrixAlgebraKit: findtruncated

function MatrixAlgebraKit.findtruncated(
values::AbstractVector, strategy::TruncationDegenerate
)
Expand Down
89 changes: 69 additions & 20 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
using LinearAlgebra: LinearAlgebra
using MatrixAlgebraKit: MatrixAlgebraKit

for f in (
:qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth,
:factorize,
for (f, f_mat) in (
(:qr, :(MatrixAlgebra.qr)),
(:lq, :(MatrixAlgebra.lq)),
(:left_polar, :(MatrixAlgebraKit.left_polar)),
(:right_polar, :(MatrixAlgebraKit.right_polar)),
(:polar, :(MatrixAlgebra.polar)),
(:left_orth, :(MatrixAlgebraKit.left_orth)),
(:right_orth, :(MatrixAlgebraKit.right_orth)),
(:orth, :(MatrixAlgebra.orth)),
(:factorize, :(MatrixAlgebra.factorize)),
)
@eval begin
function $f(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
X, Y = MatrixAlgebra.$f(A_mat; kwargs...)
X, Y = $f_mat(A_mat; kwargs...)
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
axes_codomain, axes_domain = blocks(axes(A)[biperm])
axes_X = tuplemortar((axes_codomain, (axes(X, 2),)))
Expand Down Expand Up @@ -219,18 +226,25 @@ See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlg
"""
eigen

function eigen(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
function eigen!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
# tensor to matrix
A_mat = matricize(style, A, ndims_codomain)
D, V = MatrixAlgebra.eigen!(A_mat; kwargs...)
D, V = MatrixAlgebra.eigen!!(A_mat; kwargs...)
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
axes_codomain, = blocks(axes(A)[biperm])
axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),)))
# TODO: Make sure `D` has the same basis as `V`.
return D, unmatricize(style, V, axes_V)
end
function eigen!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigen!!(FusionStyle(A), A, ndims_codomain; kwargs...)
end

function eigen(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigen!!(style, copy(A), ndims_codomain; kwargs...)
end
function eigen(A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigen(FusionStyle(A), A, ndims_codomain; kwargs...)
return eigen!!(copy(A), ndims_codomain; kwargs...)
end

"""
Expand All @@ -253,12 +267,19 @@ See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`.
"""
eigvals

function eigvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
function eigvals!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
return MatrixAlgebra.eigvals!(A_mat; kwargs...)
return MatrixAlgebra.eigvals!!(A_mat; kwargs...)
end
function eigvals!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigvals!!(FusionStyle(A), A, ndims_codomain; kwargs...)
end

function eigvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigvals!!(style, copy(A), ndims_codomain; kwargs...)
end
function eigvals(A::AbstractArray, ndims_codomain::Val; kwargs...)
return eigvals(FusionStyle(A), A, ndims_codomain; kwargs...)
return eigvals!!(copy(A), ndims_codomain; kwargs...)
end

"""
Expand All @@ -282,17 +303,24 @@ See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `Mat
"""
svd

function svd(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
function svd!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...)
U, S, Vᴴ = MatrixAlgebra.svd!!(A_mat; kwargs...)
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
axes_codomain, axes_domain = blocks(axes(A)[biperm])
axes_U = tuplemortar((axes_codomain, (axes(U, 2),)))
axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain))
return unmatricize(style, U, axes_U), S, unmatricize(style, Vᴴ, axes_Vᴴ)
end
function svd!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
return svd!!(FusionStyle(A), A, ndims_codomain; kwargs...)
end

function svd(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
return svd!!(style, copy(A), ndims_codomain; kwargs...)
end
function svd(A::AbstractArray, ndims_codomain::Val; kwargs...)
return svd(FusionStyle(A), A, ndims_codomain; kwargs...)
return svd!!(copy(A), ndims_codomain; kwargs...)
end

"""
Expand All @@ -309,12 +337,19 @@ See also `MatrixAlgebraKit.svd_vals!`.
"""
svdvals

function svdvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
function svdvals!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
A_mat = matricize(style, A, ndims_codomain)
return MatrixAlgebra.svdvals!(A_mat)
return MatrixAlgebra.svdvals!!(A_mat)
end
function svdvals!!(A::AbstractArray, ndims_codomain::Val)
return svdvals!!(FusionStyle(A), A, ndims_codomain)
end

function svdvals(style::FusionStyle, A::AbstractArray, ndims_codomain::Val)
return svdvals!!(style, copy(A), ndims_codomain)
end
function svdvals(A::AbstractArray, ndims_codomain::Val)
return svdvals(FusionStyle(A), A, ndims_codomain)
return svdvals!!(copy(A), ndims_codomain)
end

"""
Expand All @@ -338,16 +373,23 @@ The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`.
"""
left_null

function left_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
function left_null!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
N = MatrixAlgebraKit.left_null!(A_mat; kwargs...)
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
axes_codomain = first(blocks(axes(A)[biperm]))
axes_N = tuplemortar((axes_codomain, (axes(N, 2),)))
return unmatricize(style, N, axes_N)
end
function left_null!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
return left_null!!(FusionStyle(A), A, ndims_codomain; kwargs...)
end

function left_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
return left_null!!(style, copy(A), ndims_codomain; kwargs...)
end
function left_null(A::AbstractArray, ndims_codomain::Val; kwargs...)
return left_null(FusionStyle(A), A, ndims_codomain; kwargs...)
return left_null!!(copy(A), ndims_codomain; kwargs...)
end

"""
Expand All @@ -371,14 +413,21 @@ The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`.
"""
right_null

function right_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
function right_null!!(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
A_mat = matricize(style, A, ndims_codomain)
Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...)
biperm = trivialbiperm(ndims_codomain, Val(ndims(A)))
axes_domain = last(blocks((axes(A)[biperm])))
axes_Nᴴ = tuplemortar(((axes(Nᴴ, 1),), axes_domain))
return unmatricize(style, Nᴴ, axes_Nᴴ)
end
function right_null!!(A::AbstractArray, ndims_codomain::Val; kwargs...)
return right_null!!(FusionStyle(A), A, ndims_codomain; kwargs...)
end

function right_null(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...)
return right_null!!(style, copy(A), ndims_codomain; kwargs...)
end
function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...)
return right_null(FusionStyle(A), A, ndims_codomain; kwargs...)
return right_null!!(copy(A), ndims_codomain; kwargs...)
end
Loading
Loading