diff --git a/Project.toml b/Project.toml index 7d754d817..137aa050c 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.5.0" +MatrixAlgebraKit = "0.6.0" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Printf = "1" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 74d23c8ce..c6e003450 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -79,10 +79,12 @@ export left_orth, right_orth, left_null, right_null, qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!, svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc, exp, exp!, - eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eig_full!, eig_full, eig_trunc!, - eig_trunc, - eigh_vals!, eigh_vals, eig_vals!, eig_vals, - isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond + eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals, + eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals, + ishermitian, project_hermitian, project_hermitian!, + isantihermitian, project_antihermitian, project_antihermitian!, + isisometric, isunitary, project_isometric, project_isometric!, + isposdef, isposdef!, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! @@ -135,7 +137,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr, adjoint, adjoint!, transpose, transpose!, lu, pinv, sylvester, eigen, eigen!, svd, svd!, - isposdef, isposdef!, ishermitian, rank, cond, + isposdef, isposdef!, rank, cond, Diagonal, Hermitian using MatrixAlgebraKit diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index a37c40531..81fbe9a50 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -186,7 +186,7 @@ function tsvd(t::AbstractTensorMap; kwargs...) Base.depwarn("p is a deprecated kwarg, and should be specified through the truncation strategy", :tsvd) kwargs = _drop_p(; kwargs...) end - return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...) + return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : (svd_compact(t; kwargs...)..., abs(zero(scalartype(t)))) end function tsvd!(t::AbstractTensorMap; kwargs...) Base.depwarn("`tsvd!` is deprecated, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!) diff --git a/src/factorizations/adjoint.jl b/src/factorizations/adjoint.jl index 30a68c138..29f27203c 100644 --- a/src/factorizations/adjoint.jl +++ b/src/factorizations/adjoint.jl @@ -6,46 +6,54 @@ _adjoint(alg::MAK.LAPACK_HouseholderQR) = MAK.LAPACK_HouseholderLQ(; alg.kwargs. _adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs...) _adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...) _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...) -_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg)) +_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg)) _adjoint(alg::AbstractAlgorithm) = alg -# 1-arg functions -function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg))) -end -function MAK.initialize_output( - ::typeof(right_null!), t::AdjointTensorMap, - alg::AbstractAlgorithm - ) - return adjoint(MAK.initialize_output(left_null!, adjoint(t), _adjoint(alg))) +for f in + [ + :svd_compact, :svd_full, :svd_vals, + :qr_compact, :qr_full, :qr_null, + :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, + :left_polar, :right_polar, + :project_hermitian, :project_antihermitian, :project_isometric, + ] + f! = Symbol(f, :!) + # just return the algorithm for the parent type since we are mapping this with + # `_adjoint` afterwards anyways. + # TODO: properly handle these cases + @eval MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AdjointTensorMap} = + MAK.default_algorithm($f!, TensorKit.parenttype(T); kwargs...) end -function MAK.left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) - right_null!(adjoint(t), adjoint(N), _adjoint(alg)) - return N -end -function MAK.right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) - left_null!(adjoint(t), adjoint(N), _adjoint(alg)) - return N -end +# 1-arg functions +MAK.initialize_output(::typeof(qr_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) = + adjoint(MAK.initialize_output(lq_null!, adjoint(t), _adjoint(alg))) +MAK.initialize_output(::typeof(lq_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) = + adjoint(MAK.initialize_output(qr_null!, adjoint(t), _adjoint(alg))) -function MAK.is_left_isometry(t::AdjointTensorMap; kwargs...) - return is_right_isometry(adjoint(t); kwargs...) -end -function MAK.is_right_isometry(t::AdjointTensorMap; kwargs...) - return is_left_isometry(adjoint(t); kwargs...) -end +MAK.qr_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) = + lq_null!(adjoint(t), adjoint(N), _adjoint(alg)) +MAK.lq_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) = + qr_null!(adjoint(t), adjoint(N), _adjoint(alg)) + +MAK.is_left_isometric(t::AdjointTensorMap; kwargs...) = + MAK.is_right_isometric(adjoint(t); kwargs...) +MAK.is_right_isometric(t::AdjointTensorMap; kwargs...) = + MAK.is_left_isometric(adjoint(t); kwargs...) # 2-arg functions -for (left_f!, right_f!) in zip( - (:qr_full!, :qr_compact!, :left_polar!, :left_orth!), - (:lq_full!, :lq_compact!, :right_polar!, :right_orth!) +for (left_f, right_f) in zip( + (:qr_full, :qr_compact, :left_polar), + (:lq_full, :lq_compact, :right_polar) ) - @eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap) - return adjoint(MAK.copy_input($right_f!, adjoint(t))) + left_f! = Symbol(left_f, :!) + right_f! = Symbol(right_f, :!) + @eval function MAK.copy_input(::typeof($left_f), t::AdjointTensorMap) + return adjoint(MAK.copy_input($right_f, adjoint(t))) end - @eval function MAK.copy_input(::typeof($right_f!), t::AdjointTensorMap) - return adjoint(MAK.copy_input($left_f!, adjoint(t))) + @eval function MAK.copy_input(::typeof($right_f), t::AdjointTensorMap) + return adjoint(MAK.copy_input($left_f, adjoint(t))) end @eval function MAK.initialize_output( @@ -60,19 +68,20 @@ for (left_f!, right_f!) in zip( end @eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) - $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) - return F + F′ = $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) end @eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) - $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) - return F + F′ = $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) end end # 3-arg functions -for f! in (:svd_full!, :svd_compact!, :svd_trunc!) - @eval function MAK.copy_input(::typeof($f!), t::AdjointTensorMap) - return adjoint(MAK.copy_input($f!, adjoint(t))) +for f in (:svd_full, :svd_compact) + f! = Symbol(f, :!) + @eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap) + return adjoint(MAK.copy_input($f, adjoint(t))) end @eval function MAK.initialize_output( @@ -80,9 +89,10 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) ) return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg)))) end + @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) - $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) - return F + F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) end # disambiguate by prohibition @@ -92,17 +102,9 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) throw(MethodError($f!, (t, alg))) end end + # avoid amgiguity -function MAK.initialize_output( - ::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm - ) - return MAK.initialize_output(svd_compact!, t, alg.alg) -end -# to fix ambiguity -function MAK.svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm) - USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) - return MAK.truncate(svd_trunc!, USVᴴ′, alg.trunc) -end -function MAK.svd_compact!(t::AdjointTensorMap, USVᴴ, alg::DiagonalAlgorithm) - return MAK.svd_compact!(t, USVᴴ, alg.alg) +function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm) + F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) end diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index 4de99f973..cea915995 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -81,44 +81,11 @@ for f! in :eigh_trunc!, :right_orth!, :left_orth!, ) @eval function MAK.$f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) - MAK.check_input($f!, d, F, alg) $f!(_repack_diagonal(d), _repack_diagonal.(F), alg) return F end end -for f! in (:qr_full!, :qr_compact!) - @eval function MAK.check_input( - ::typeof($f!), d::AbstractTensorMap, QR, ::DiagonalAlgorithm - ) - Q, R = QR - @assert d isa DiagonalTensorMap - @assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap - @check_scalar Q d - @check_scalar R d - @check_space(Q, space(d)) - @check_space(R, space(d)) - - return nothing - end -end - -for f! in (:lq_full!, :lq_compact!) - @eval function MAK.check_input( - ::typeof($f!), d::AbstractTensorMap, LQ, ::DiagonalAlgorithm - ) - L, Q = LQ - @assert d isa DiagonalTensorMap - @assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap - @check_scalar Q d - @check_scalar L d - @check_space(Q, space(d)) - @check_space(L, space(d)) - - return nothing - end -end - # disambiguate function MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) return svd_full!(t, USVᴴ, alg) @@ -126,10 +93,8 @@ end # f_vals # ------ - for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) @eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm) - MAK.check_input($f!, d, V, alg) $f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg) return V end @@ -140,64 +105,3 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) return DiagonalTensorMap(data, d.domain) end end - -function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - D, V = DV - - @assert D isa DiagonalTensorMap - @assert V isa AbstractTensorMap - - # scalartype checks - @check_scalar D t - @check_scalar V t - - # space checks - @check_space D space(t) - @check_space V space(t) - - return nothing -end - -function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - D, V = DV - - @assert D isa DiagonalTensorMap - @assert V isa AbstractTensorMap - - # scalartype checks - @check_scalar D t real - @check_scalar V t - - # space checks - @check_space D space(t) - @check_space V space(t) - - return nothing -end - -function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) - @assert D isa DiagonalTensorMap - @check_scalar D t - @check_space D space(t) - return nothing -end - -function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) - @assert D isa DiagonalTensorMap - @check_scalar D t real - @check_space D space(t) - return nothing -end - -function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) - @assert D isa DiagonalTensorMap - @check_scalar D t real - @check_space D space(t) - return nothing -end diff --git a/src/factorizations/factorizations.jl b/src/factorizations/factorizations.jl index 29d2c16a6..44d7e2315 100644 --- a/src/factorizations/factorizations.jl +++ b/src/factorizations/factorizations.jl @@ -18,8 +18,7 @@ import MatrixAlgebraKit as MAK using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder -using MatrixAlgebraKit: left_orth_polar!, right_orth_polar!, left_orth_svd!, - right_orth_svd!, left_null_svd!, right_null_svd!, diagview +using MatrixAlgebraKit: diagview include("utility.jl") include("matrixalgebrakit.jl") @@ -30,11 +29,6 @@ include("pullbacks.jl") TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) -function MatrixAlgebraKit.isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) - t = permute(t, (p₁, p₂); copy = false) - return isisometry(t) -end - #------------------------------# # LinearAlgebra overloads #------------------------------# @@ -61,15 +55,24 @@ LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) #--------------------------------------------------# # Checks for hermiticity and positive definiteness # #--------------------------------------------------# -function LinearAlgebra.ishermitian(t::AbstractTensorMap) - domain(t) == codomain(t) || return false - InnerProductStyle(t) === EuclideanInnerProduct() || return false # hermiticity only defined for euclidean - for (c, b) in blocks(t) - ishermitian(b) || return false +function _blockmap(f; kwargs...) + return function ((c, b)) + return f(b; kwargs...) end - return true end +function MAK.ishermitian(t::AbstractTensorMap; kwargs...) + return InnerProductStyle(t) === EuclideanInnerProduct() && + domain(t) == codomain(t) && + all(_blockmap(MAK.ishermitian; kwargs...), blocks(t)) +end +function MAK.isantihermitian(t::AbstractTensorMap; kwargs...) + return InnerProductStyle(t) === EuclideanInnerProduct() && + domain(t) == codomain(t) && + all(_blockmap(MAK.isantihermitian; kwargs...), blocks(t)) +end +LinearAlgebra.ishermitian(t::AbstractTensorMap) = MAK.ishermitian(t) + function LinearAlgebra.isposdef(t::AbstractTensorMap) return isposdef!(copy_oftype(t, factorisation_scalartype(isposdef, t))) end @@ -77,22 +80,17 @@ function LinearAlgebra.isposdef!(t::AbstractTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false - for (c, b) in blocks(t) - isposdef!(b) || return false - end - return true + return all(_blockmap(isposdef!), blocks(t)) end # TODO: tolerances are per-block, not global or weighted - does that matter? -function MatrixAlgebraKit.is_left_isometry(t::AbstractTensorMap; kwargs...) +function MAK.is_left_isometric(t::AbstractTensorMap; kwargs...) domain(t) ≾ codomain(t) || return false - f((c, b)) = MatrixAlgebraKit.is_left_isometry(b; kwargs...) - return all(f, blocks(t)) + return all(_blockmap(MAK.is_left_isometric; kwargs...), blocks(t)) end -function MatrixAlgebraKit.is_right_isometry(t::AbstractTensorMap; kwargs...) +function MAK.is_right_isometric(t::AbstractTensorMap; kwargs...) domain(t) ≿ codomain(t) || return false - f((c, b)) = MatrixAlgebraKit.is_right_isometry(b; kwargs...) - return all(f, blocks(t)) + return all(_blockmap(MAK.is_right_isometric; kwargs...), blocks(t)) end end diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index fc5f84502..412fd8761 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -2,9 +2,12 @@ # ------------------- for f in [ - :svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, - :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, - :eigh_trunc, :eigh_vals, :left_polar, :right_polar, + :svd_compact, :svd_full, :svd_vals, + :qr_compact, :qr_full, :qr_null, + :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_vals, :eigh_full, :eigh_vals, + :left_polar, :right_polar, + :project_hermitian, :project_antihermitian, :project_isometric, ] f! = Symbol(f, :!) @eval function MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AbstractTensorMap} @@ -25,38 +28,30 @@ end for f! in ( :qr_compact!, :qr_full!, :lq_compact!, :lq_full!, :eig_full!, :eigh_full!, :svd_compact!, :svd_full!, - :left_polar!, :left_orth_polar!, :right_polar!, :right_orth_polar!, - :left_orth!, :right_orth!, + :left_polar!, :right_polar!, ) @eval function MAK.$f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) - MAK.check_input($f!, t, F, alg) - - foreachblock(t, F...) do _, bs - factors = Base.tail(bs) - factors′ = $f!(first(bs), factors, alg) + foreachblock(t, F...) do _, (tblock, Fblocks...) + Fblocks′ = $f!(tblock, Fblocks, alg) # deal with the case where the output is not in-place - for (f′, f) in zip(factors′, factors) - f′ === f || copy!(f, f′) + for (b′, b) in zip(Fblocks′, Fblocks) + b === b′ || copy!(b, b′) end return nothing end - return F end end # Handle these separately because single output instead of tuple -for f! in (:qr_null!, :lq_null!) +for f! in (:qr_null!, :lq_null!, :project_hermitian!, :project_antihermitian!, :project_isometric!) @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - MAK.check_input($f!, t, N, alg) - - foreachblock(t, N) do _, (b, n) - n′ = $f!(b, n, alg) + foreachblock(t, N) do _, (tblock, Nblock) + Nblock′ = $f!(tblock, Nblock, alg) # deal with the case where the output is not the same as the input - n === n′ || copy!(n, n′) + Nblock === Nblock′ || copy!(Nblock, Nblock′) return nothing end - return N end end @@ -64,74 +59,18 @@ end # Handle these separately because single output instead of tuple for f! in (:svd_vals!, :eig_vals!, :eigh_vals!) @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - MAK.check_input($f!, t, N, alg) - - foreachblock(t, N) do _, (b, n) - n′ = $f!(b, diagview(n), alg) + foreachblock(t, N) do _, (tblock, Nblock) + Nblock′ = $f!(tblock, diagview(Nblock), alg) # deal with the case where the output is not the same as the input - diagview(n) === n′ || copy!(diagview(n), n′) + diagview(Nblock) === Nblock′ || copy!(diagview(Nblock), Nblock′) return nothing end - return N end end # Singular value decomposition # ---------------------------- -function MAK.check_input(::typeof(svd_full!), t::AbstractTensorMap, USVᴴ, ::AbstractAlgorithm) - U, S, Vᴴ = USVᴴ - - # type checks - @assert U isa AbstractTensorMap - @assert S isa AbstractTensorMap - @assert Vᴴ isa AbstractTensorMap - - # scalartype checks - @check_scalar U t - @check_scalar S t real - @check_scalar Vᴴ t - - # space checks - V_cod = fuse(codomain(t)) - V_dom = fuse(domain(t)) - @check_space(U, codomain(t) ← V_cod) - @check_space(S, V_cod ← V_dom) - @check_space(Vᴴ, V_dom ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, ::AbstractAlgorithm) - U, S, Vᴴ = USVᴴ - - # type checks - @assert U isa AbstractTensorMap - @assert S isa DiagonalTensorMap - @assert Vᴴ isa AbstractTensorMap - - # scalartype checks - @check_scalar U t - @check_scalar S t real - @check_scalar Vᴴ t - - # space checks - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(U, codomain(t) ← V_cod) - @check_space(S, V_cod ← V_dom) - @check_space(Vᴴ, V_dom ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) - @check_scalar D t real - @assert D isa DiagonalTensorMap - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(D, V_cod ← V_dom) - return nothing -end - function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) @@ -149,11 +88,6 @@ function MAK.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, ::A return U, S, Vᴴ end -# TODO: remove this once `AbstractMatrix` specialization is removed in MatrixAlgebraKit -function MAK.initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, alg::TruncatedAlgorithm) - return MAK.initialize_output(svd_compact!, t, alg.alg) -end - function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) @@ -161,66 +95,6 @@ end # Eigenvalue decomposition # ------------------------ -function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - D, V = DV - - # type checks - @assert D isa DiagonalTensorMap - @assert V isa AbstractTensorMap - - # scalartype checks - @check_scalar D t real - @check_scalar V t - - # space checks - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - @check_space(V, codomain(t) ← V_D) - - return nothing -end - -function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - D, V = DV - - # type checks - @assert D isa DiagonalTensorMap - @assert V isa AbstractTensorMap - - # scalartype checks - @check_scalar D t complex - @check_scalar V t complex - - # space checks - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - @check_space(V, codomain(t) ← V_D) - - return nothing -end - -function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) - @check_scalar D t real - @assert D isa DiagonalTensorMap - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - return nothing -end - -function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) - @check_scalar D t complex - @assert D isa DiagonalTensorMap - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - return nothing -end - function MAK.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) @@ -251,56 +125,6 @@ end # QR decomposition # ---------------- -function MAK.check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) - Q, R = QR - - # type checks - @assert Q isa AbstractTensorMap - @assert R isa AbstractTensorMap - - # scalartype checks - @check_scalar Q t - @check_scalar R t - - # space checks - V_Q = fuse(codomain(t)) - @check_space(Q, codomain(t) ← V_Q) - @check_space(R, V_Q ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(qr_compact!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) - Q, R = QR - - # type checks - @assert Q isa AbstractTensorMap - @assert R isa AbstractTensorMap - - # scalartype checks - @check_scalar Q t - @check_scalar R t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(Q, codomain(t) ← V_Q) - @check_space(R, V_Q ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(qr_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - @check_space(N, codomain(t) ← V_N) - - return nothing -end - function MAK.initialize_output(::typeof(qr_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_Q = fuse(codomain(t)) Q = similar(t, codomain(t) ← V_Q) @@ -324,56 +148,6 @@ end # LQ decomposition # ---------------- -function MAK.check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) - L, Q = LQ - - # type checks - @assert L isa AbstractTensorMap - @assert Q isa AbstractTensorMap - - # scalartype checks - @check_scalar L t - @check_scalar Q t - - # space checks - V_Q = fuse(domain(t)) - @check_space(L, codomain(t) ← V_Q) - @check_space(Q, V_Q ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) - L, Q = LQ - - # type checks - @assert L isa AbstractTensorMap - @assert Q isa AbstractTensorMap - - # scalartype checks - @check_scalar L t - @check_scalar Q t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(L, codomain(t) ← V_Q) - @check_space(Q, V_Q ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(lq_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(domain(t)), V_Q) - @check_space(N, V_N ← domain(t)) - - return nothing -end - function MAK.initialize_output(::typeof(lq_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_Q = fuse(domain(t)) L = similar(t, codomain(t) ← V_Q) @@ -397,261 +171,23 @@ end # Polar decomposition # ------------------- -function MAK.check_input(::typeof(left_polar!), t::AbstractTensorMap, WP, ::AbstractAlgorithm) - codomain(t) ≿ domain(t) || - throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) - - W, P = WP - @assert W isa AbstractTensorMap - @assert P isa AbstractTensorMap - - # scalartype checks - @check_scalar W t - @check_scalar P t - - # space checks - @check_space(W, space(t)) - @check_space(P, domain(t) ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, WP, ::AbstractAlgorithm) - codomain(t) ≿ domain(t) || - throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) - - W, P = WP - @assert W isa AbstractTensorMap - @assert P isa AbstractTensorMap - - # scalartype checks - @check_scalar W t - @check_scalar P t - - # space checks - VW = fuse(domain(t)) - @check_space(W, codomain(t) ← VW) - @check_space(P, VW ← domain(t)) - - return nothing -end - function MAK.initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::AbstractAlgorithm) W = similar(t, space(t)) P = similar(t, domain(t) ← domain(t)) return W, P end -function MAK.check_input(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, ::AbstractAlgorithm) - codomain(t) ≾ domain(t) || - throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) - - P, Wᴴ = PWᴴ - @assert P isa AbstractTensorMap - @assert Wᴴ isa AbstractTensorMap - - # scalartype checks - @check_scalar P t - @check_scalar Wᴴ t - - # space checks - @check_space(P, codomain(t) ← codomain(t)) - @check_space(Wᴴ, space(t)) - - return nothing -end - -function MAK.check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, PWᴴ, ::AbstractAlgorithm) - codomain(t) ≾ domain(t) || - throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) - - P, Wᴴ = PWᴴ - @assert P isa AbstractTensorMap - @assert Wᴴ isa AbstractTensorMap - - # scalartype checks - @check_scalar P t - @check_scalar Wᴴ t - - # space checks - VW = fuse(codomain(t)) - @check_space(P, codomain(t) ← VW) - @check_space(Wᴴ, VW ← domain(t)) - - return nothing -end - function MAK.initialize_output(::typeof(right_polar!), t::AbstractTensorMap, ::AbstractAlgorithm) P = similar(t, codomain(t) ← codomain(t)) Wᴴ = similar(t, space(t)) return P, Wᴴ end -# Orthogonalization -# ----------------- -function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractAlgorithm) - V, C = VC - - # scalartype checks - @check_scalar V t - isnothing(C) || @check_scalar C t - - # space checks - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(V, codomain(t) ← V_C) - isnothing(C) || @check_space(C, V_C ← domain(t)) - - return nothing -end - -function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::AbstractAlgorithm) - C, Vᴴ = CVᴴ - - # scalartype checks - isnothing(C) || @check_scalar C t - @check_scalar Vᴴ t - - # space checks - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - isnothing(C) || @check_space(C, codomain(t) ← V_C) - @check_space(Vᴴ, V_C ← domain(t)) - - return nothing -end - -function MAK.initialize_output(::typeof(left_orth!), t::AbstractTensorMap) - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - V = similar(t, codomain(t) ← V_C) - C = similar(t, V_C ← domain(t)) - return V, C -end - -function MAK.initialize_output(::typeof(right_orth!), t::AbstractTensorMap) - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - C = similar(t, codomain(t) ← V_C) - Vᴴ = similar(t, V_C ← domain(t)) - return C, Vᴴ -end - -# This is a rework of the dispatch logic in order to avoid having to deal with having to -# allocate the output before knowing the kind of decomposition. In particular, here I disable -# providing output arguments for left_ and right_orth. -# This is mainly because polar decompositions have different shapes, and SVD for Diagonal -# also does -function MAK.left_orth!( - t::AbstractTensorMap; - trunc::TruncationStrategy = notrunc(), - kind = trunc == notrunc() ? :qr : :svd, - alg_qr = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - trunc == notrunc() || kind === :svd || - throw(ArgumentError("truncation not supported for left_orth with kind = $kind")) - - return if kind === :qr - alg_qr isa NamedTuple ? qr_compact!(t; alg_qr...) : qr_compact!(t; alg = alg_qr) - elseif kind === :polar - alg_polar isa NamedTuple ? left_orth_polar!(t; alg_polar...) : - left_orth_polar!(t; alg = alg_polar) - elseif kind === :svd - alg_svd isa NamedTuple ? left_orth_svd!(t; trunc, alg_svd...) : - left_orth_svd!(t; trunc, alg = alg_svd) - else - throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`")) - end -end -function MAK.right_orth!( - t::AbstractTensorMap; - trunc::TruncationStrategy = notrunc(), - kind = trunc == notrunc() ? :lq : :svd, - alg_lq = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - trunc == notrunc() || kind === :svd || - throw(ArgumentError("truncation not supported for right_orth with kind = $kind")) - - return if kind === :lq - alg_lq isa NamedTuple ? lq_compact!(t; alg_lq...) : lq_compact!(t; alg = alg_lq) - elseif kind === :polar - alg_polar isa NamedTuple ? right_orth_polar!(t; alg_polar...) : - right_orth_polar!(t; alg = alg_polar) - elseif kind === :svd - alg_svd isa NamedTuple ? right_orth_svd!(t; trunc, alg_svd...) : - right_orth_svd!(t; trunc, alg = alg_svd) - else - throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`")) - end -end - -function MAK.left_orth_polar!(t::AbstractTensorMap; alg = nothing, kwargs...) - alg′ = MAK.select_algorithm(left_polar!, t, alg; kwargs...) - VC = MAK.initialize_output(left_orth!, t) - return left_orth_polar!(t, VC, alg′) -end -function MAK.left_orth_polar!(t::AbstractTensorMap, VC, alg) - alg′ = MAK.select_algorithm(left_polar!, t, alg) - return left_orth_polar!(t, VC, alg′) -end -function MAK.right_orth_polar!(t::AbstractTensorMap; alg = nothing, kwargs...) - alg′ = MAK.select_algorithm(right_polar!, t, alg; kwargs...) - CVᴴ = MAK.initialize_output(right_orth!, t) - return right_orth_polar!(t, CVᴴ, alg′) -end -function MAK.right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) - alg′ = MAK.select_algorithm(right_polar!, t, alg) - return right_orth_polar!(t, CVᴴ, alg′) -end - -function MAK.left_orth_svd!(t::AbstractTensorMap; trunc = notrunc(), kwargs...) - U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : - svd_trunc!(t; trunc, kwargs...) - return U, lmul!(S, Vᴴ) -end -function MAK.right_orth_svd!(t::AbstractTensorMap; trunc = notrunc(), kwargs...) - U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : - svd_trunc!(t; trunc, kwargs...) - return rmul!(U, S), Vᴴ -end - -# Nullspace -# --------- -function MAK.check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - @check_space(N, codomain(t) ← V_N) - - return nothing -end - -function MAK.check_input(::typeof(right_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(domain(t)), V_Q) - @check_space(N, V_N ← domain(t)) - - return nothing -end - -function MAK.initialize_output(::typeof(left_null!), t::AbstractTensorMap) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - N = similar(t, codomain(t) ← V_N) - return N -end - -function MAK.initialize_output(::typeof(right_null!), t::AbstractTensorMap) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(domain(t)), V_Q) - N = similar(t, V_N ← domain(t)) - return N -end - -for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!)) - @eval function MAK.$f_svd!(t::AbstractTensorMap, N, alg, ::Nothing = nothing) - return $f!(t, N; alg_svd = alg) - end -end +# Projections +# ----------- +MAK.initialize_output(::typeof(project_hermitian!), tsrc::AbstractTensorMap, ::AbstractAlgorithm) = + tsrc +MAK.initialize_output(::typeof(project_antihermitian!), tsrc::AbstractTensorMap, ::AbstractAlgorithm) = + tsrc +MAK.initialize_output(::typeof(project_isometric!), tsrc::AbstractTensorMap, ::AbstractAlgorithm) = + similar(tsrc) diff --git a/src/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl index eeeec597f..488089809 100644 --- a/src/factorizations/pullbacks.jl +++ b/src/factorizations/pullbacks.jl @@ -32,12 +32,13 @@ for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!) Δt::AbstractTensorMap, t::AbstractTensorMap, F, ΔF, inds = _notrunc_ind(t); kwargs... ) - for (c, ind) in inds - Δb = block(Δt, c) - b = block(t, c) + foreachblock(Δt, t) do c, (Δb, b) + haskey(inds, c) || return nothing + ind = inds[c] Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) MAK.$pullback!(Δb, b, Fc, ΔFc, ind; kwargs...) + return nothing end return Δt end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 19f39759c..84f2569e0 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -21,7 +21,7 @@ Truncation strategy to keep the first values for each sector when sorted accordi such that the resulting vector space is no greater than `V`. """ function truncspace(space::ElementarySpace; by = abs, rev::Bool = true) - isdual(space) && throw(ArgumentError("resulting vector space is never dual")) + isdual(space) && throw(ArgumentError("truncation space should not be dual")) return TruncationSpace(space, by, rev) end @@ -37,7 +37,7 @@ function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, @view(block(tsrc, c)[:, I])) + copy!(b, view(block(tsrc, c), :, I)) end return tdst end @@ -45,7 +45,7 @@ function truncate_codomain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, in for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, @view(block(tsrc, c)[I, :])) + copy!(b, view(block(tsrc, c), I, :)) end return tdst end @@ -76,20 +76,52 @@ function MAK.truncate( end function MAK.truncate( - ::typeof(left_null!), - (U, S)::Tuple{AbstractTensorMap, AbstractTensorMap}, - strategy::MatrixAlgebraKit.TruncationStrategy + ::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy + ) + extended_S = SectorDict( + c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 1) - size(b, 2)))) + for (c, b) in blocks(S) + ) + ind = MAK.findtruncated(extended_S, strategy) + V_truncated = truncate_space(space(S, 1), ind) + Ũ = similar(U, codomain(U) ← V_truncated) + truncate_domain!(Ũ, U, ind) + return Ũ, ind +end +function MAK.truncate( + ::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) extended_S = SectorDict( c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) for (c, b) in blocks(S) ) ind = MAK.findtruncated(extended_S, strategy) + V_truncated = truncate_space(dual(space(S, 2)), ind) + Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) + truncate_codomain!(Ṽᴴ, Vᴴ, ind) + return Ṽᴴ, ind +end + +# special case `NoTruncation` for null: should keep exact zeros due to rectangularity +# need to specialize to avoid ambiguity with special case in MatrixAlgebraKit +function MAK.truncate( + ::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::NoTruncation + ) + ind = SectorDict(c => (size(b, 2) + 1):size(b, 1) for (c, b) in blocks(S)) V_truncated = truncate_space(space(S, 1), ind) Ũ = similar(U, codomain(U) ← V_truncated) truncate_domain!(Ũ, U, ind) return Ũ, ind end +function MAK.truncate( + ::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::NoTruncation + ) + ind = SectorDict(c => (size(b, 1) + 1):size(b, 2) for (c, b) in blocks(S)) + V_truncated = truncate_space(dual(space(S, 2)), ind) + Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) + truncate_codomain!(Ṽᴴ, Vᴴ, ind) + return Ṽᴴ, ind +end for f! in (:eig_trunc!, :eigh_trunc!) @eval function MAK.truncate( @@ -113,7 +145,8 @@ end # Find truncation # --------------- # auxiliary functions -rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, TensorKit._norm(S, p) * rtol) : atol +rtol_to_atol(S, p, atol, rtol) = + rtol == 0 ? atol : max(atol, TensorKit._norm(S, p, norm(zero(scalartype(valtype(S))))) * rtol) function _compute_truncerr(Σdata, truncdim, p = 2) I = keytype(Σdata) @@ -242,3 +275,16 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationIntersect ) for c in intersect(map(keys, inds)...) ) end + +# Truncation error +# ---------------- +MAK.truncation_error(values::SectorDict, ind) = + MAK.truncation_error!(SectorDict(c => copy(v) for (c, v) in values), ind) + +function MAK.truncation_error!(values::SectorDict, ind) + for (c, ind_c) in ind + v = values[c] + v[ind_c] .= zero(eltype(v)) + end + return TensorKit._norm(values, 2, zero(real(eltype(valtype(values))))) +end diff --git a/src/factorizations/utility.jl b/src/factorizations/utility.jl index 0d2c3575b..4d5e1bb00 100644 --- a/src/factorizations/utility.jl +++ b/src/factorizations/utility.jl @@ -1,11 +1,3 @@ -# convenience to set default -macro check_space(x, V) - return esc(:($MatrixAlgebraKit.@check_size($x, $V, $space))) -end -macro check_scalar(x, y, op = :identity, eltype = :scalartype) - return esc(:($MatrixAlgebraKit.@check_scalar($x, $y, $op, $eltype))) -end - function factorisation_scalartype(t::AbstractTensorMap) T = scalartype(t) return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index 2a229f2c5..ca484e77b 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -11,6 +11,8 @@ struct AdjointTensorMap{T, S, N₁, N₂, TT <: AbstractTensorMap{T, S, N₂, N parent::TT end Base.parent(t::AdjointTensorMap) = t.parent +parenttype(t::AdjointTensorMap) = parenttype(typeof(t)) +parenttype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = TT # Constructor: construct from taking adjoint of a tensor Base.adjoint(t::AdjointTensorMap) = parent(t) diff --git a/test/autodiff/ad.jl b/test/autodiff/ad.jl index 15c8c387e..53f007b21 100644 --- a/test/autodiff/ad.jl +++ b/test/autodiff/ad.jl @@ -77,6 +77,9 @@ function test_ad_rrule(f, args...; check_inferred = false, kwargs...) return nothing end +# project_hermitian is non-differentiable for now +_project_hermitian(x) = (x + x') / 2 + # Gauge fixing tangents # --------------------- function remove_qrgauge_dependence!(ΔQ, t, Q) @@ -90,7 +93,6 @@ function remove_qrgauge_dependence!(ΔQ, t, Q) end return ΔQ end - function remove_lqgauge_dependence!(ΔQ, t, Q) for (c, b) in blocks(ΔQ) m, n = size(block(t, c)) @@ -103,7 +105,7 @@ function remove_lqgauge_dependence!(ΔQ, t, Q) return ΔQ end function remove_eiggauge_dependence!( - ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) ) gaugepart = V' * ΔV for (c, b) in blocks(gaugepart) @@ -119,10 +121,9 @@ function remove_eiggauge_dependence!( return ΔV end function remove_eighgauge_dependence!( - ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) ) - gaugepart = V' * ΔV - gaugepart = (gaugepart - gaugepart') / 2 + gaugepart = project_antihermitian!(V' * ΔV) for (c, b) in blocks(gaugepart) Dc = diagview(block(D, c)) # for some reason this fails only on tests, and I cannot reproduce it in an @@ -136,10 +137,9 @@ function remove_eighgauge_dependence!( return ΔV end function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) + ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S) ) - gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' - gaugepart = (gaugepart - gaugepart') / 2 + gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ') for (c, b) in blocks(gaugepart) Sd = diagview(block(S, c)) # for some reason this fails only on tests, and I cannot reproduce it in an @@ -153,8 +153,6 @@ function remove_svdgauge_dependence!( return ΔU, ΔVᴴ end -project_hermitian(A) = (A + A') / 2 - # Tests # ----- @@ -569,7 +567,7 @@ for V in spacelist remove_eighgauge_dependence!(Δv, d, v) # necessary for FiniteDifferences to not complain - eigh_full′ = eigh_full ∘ project_hermitian + eigh_full′ = eigh_full ∘ _project_hermitian test_ad_rrule(eigh_full′, t; output_tangent = (Δd, Δv), atol, rtol) test_ad_rrule(first ∘ eigh_full′, t; output_tangent = Δd, atol, rtol) @@ -595,49 +593,46 @@ for V in spacelist test_ad_rrule(svd_compact, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol) test_ad_rrule(svd_compact, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol) - # TODO: I'm not sure how to properly test with spaces that might change - # with the finite-difference methods, as then the jacobian is ill-defined. - - trunc = truncrank(max(2, round(Int, min(dim(domain(t)), dim(codomain(t))) * (3 / 4)))) + # Testing truncation with finitedifferences is RNG-prone since the + # Jacobian changes size if the truncation space changes, causing errors. + # So, first test the fixed space case, then do more limited testing on + # some gradients and compare to the fixed space case + V_trunc = spacetype(t)(c => div(min(size(b)...), 2) for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) USVᴴ_trunc = svd_trunc(t; trunc) - ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) + ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc))) remove_svdgauge_dependence!( - ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol - ) - # test_ad_rrule(svd_trunc, t; - # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) - - trunc = truncspace(space(USVᴴ_trunc[2], 1)) - USVᴴ_trunc = svd_trunc(t; trunc) - ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) - remove_svdgauge_dependence!( - ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol + ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol ) test_ad_rrule( svd_trunc, t; fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol ) - # ϵ = norm(*(USVᴴ_trunc...) - t) - # trunc = truncerror(; atol=ϵ) - # USVᴴ_trunc = svd_trunc(t; trunc) - # ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) - # remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; - # degeneracy_atol) - # test_ad_rrule(svd_trunc, t; - # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) + # attempt to construct a loss function that doesn't depend on the gauges + function f(t; trunc) + Utr, Str, Vᴴtr, ϵ = svd_trunc(t; trunc) + return LinearAlgebra.tr(Str) + LinearAlgebra.norm(Utr * Vᴴtr) + end - tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2])) + trunc = truncrank(ceil(Int, dim(V_trunc))) + USVᴴ_trunc′ = svd_trunc(t; trunc) + g1, = Zygote.gradient(x -> f(x; trunc), t) + g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc′[2], 1))), t) + @test g1 ≈ g2 + + trunc = truncerror(; atol = last(USVᴴ_trunc)) + USVᴴ_trunc′ = svd_trunc(t; trunc) + g1, = Zygote.gradient(x -> f(x; trunc), t) + g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc′[2], 1))), t) + @test g1 ≈ g2 + + tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2]); init = zero(scalartype(USVᴴ_trunc[2]))) trunc = trunctol(; atol = 10 * tol) - USVᴴ_trunc = svd_trunc(t; trunc) - ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) - remove_svdgauge_dependence!( - ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol - ) - test_ad_rrule( - svd_trunc, t; - fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol - ) + USVᴴ_trunc′ = svd_trunc(t; trunc) + g1, = Zygote.gradient(x -> f(x; trunc), t) + g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc′[2], 1))), t) + @test g1 ≈ g2 end end diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index c5b4e8748..2c1638a30 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -47,18 +47,18 @@ for V in spacelist Q, R = @constinferred qr_compact(t) @test Q * R ≈ t - @test isisometry(Q) + @test isisometric(Q) - Q, R = @constinferred left_orth(t; kind = :qr) + Q, R = @constinferred left_orth(t) @test Q * R ≈ t - @test isisometry(Q) + @test isisometric(Q) N = @constinferred qr_null(t) - @test isisometry(N) + @test isisometric(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) - N = @constinferred left_null(t; kind = :qr) - @test isisometry(N) + N = @constinferred left_null(t) + @test isisometric(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) end @@ -73,12 +73,12 @@ for V in spacelist Q, R = @constinferred qr_compact(t) @test Q * R ≈ t - @test isisometry(Q) + @test isisometric(Q) @test dim(Q) == dim(R) == dim(t) - Q, R = @constinferred left_orth(t; kind = :qr) + Q, R = @constinferred left_orth(t) @test Q * R ≈ t - @test isisometry(Q) + @test isisometric(Q) @test dim(Q) == dim(R) == dim(t) N = @constinferred qr_null(t) @@ -100,14 +100,14 @@ for V in spacelist L, Q = @constinferred lq_compact(t) @test L * Q ≈ t - @test isisometry(Q; side = :right) + @test isisometric(Q; side = :right) - L, Q = @constinferred right_orth(t; kind = :lq) + L, Q = @constinferred right_orth(t) @test L * Q ≈ t - @test isisometry(Q; side = :right) + @test isisometric(Q; side = :right) Nᴴ = @constinferred lq_null(t) - @test isisometry(Nᴴ; side = :right) + @test isisometric(Nᴴ; side = :right) @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) end @@ -122,12 +122,12 @@ for V in spacelist L, Q = @constinferred lq_compact(t) @test L * Q ≈ t - @test isisometry(Q; side = :right) + @test isisometric(Q; side = :right) @test dim(Q) == dim(L) == dim(t) - L, Q = @constinferred right_orth(t; kind = :lq) + L, Q = @constinferred right_orth(t) @test L * Q ≈ t - @test isisometry(Q; side = :right) + @test isisometric(Q; side = :right) @test dim(Q) == dim(L) == dim(t) Nᴴ = @constinferred lq_null(t) @@ -146,12 +146,12 @@ for V in spacelist @assert domain(t) ≾ codomain(t) w, p = @constinferred left_polar(t) @test w * p ≈ t - @test isisometry(w) + @test isisometric(w) @test isposdef(p) - w, p = @constinferred left_orth(t; kind = :polar) + w, p = @constinferred left_orth(t; alg = :polar) @test w * p ≈ t - @test isisometry(w) + @test isisometric(w) end for T in eltypes, @@ -160,12 +160,12 @@ for V in spacelist @assert codomain(t) ≾ domain(t) p, wᴴ = @constinferred right_polar(t) @test p * wᴴ ≈ t - @test isisometry(wᴴ; side = :right) + @test isisometric(wᴴ; side = :right) @test isposdef(p) - p, wᴴ = @constinferred right_orth(t; kind = :polar) + p, wᴴ = @constinferred right_orth(t; alg = :polar) @test p * wᴴ ≈ t - @test isisometry(wᴴ; side = :right) + @test isisometric(wᴴ; side = :right) end end @@ -185,25 +185,37 @@ for V in spacelist u, s, vᴴ = @constinferred svd_compact(t) @test u * s * vᴴ ≈ t - @test isisometry(u) + @test isisometric(u) @test isposdef(s) - @test isisometry(vᴴ; side = :right) + @test isisometric(vᴴ; side = :right) s′ = LinearAlgebra.diag(s) for (c, b) in LinearAlgebra.svdvals(t) @test b ≈ s′[c] end - v, c = @constinferred left_orth(t; kind = :svd) + v, c = @constinferred left_orth(t; alg = :svd) @test v * c ≈ t - @test isisometry(v) + @test isisometric(v) - N = @constinferred left_null(t; kind = :svd) - @test isisometry(N) + c, vᴴ = @constinferred right_orth(t; alg = :svd) + @test c * vᴴ ≈ t + @test isisometric(vᴴ; side = :right) + + N = @constinferred left_null(t; alg = :svd) + @test isisometric(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + + N = @constinferred left_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) - Nᴴ = @constinferred right_null(t; kind = :svd) - @test isisometry(Nᴴ; side = :right) + Nᴴ = @constinferred right_null(t; alg = :svd) + @test isisometric(Nᴴ; side = :right) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + + Nᴴ = @constinferred right_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(Nᴴ; side = :right) @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) end @@ -231,48 +243,55 @@ for V in spacelist @constinferred normalize!(t) - U, S, Vᴴ = @constinferred svd_trunc(t; trunc = notrunc()) + U, S, Vᴴ, ϵ = @constinferred svd_trunc(t; trunc = notrunc()) @test U * S * Vᴴ ≈ t - @test isisometry(U) - @test isisometry(Vᴴ; side = :right) + @test ϵ ≈ 0 + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) trunc = truncrank(dim(domain(S)) ÷ 2) - U1, S1, Vᴴ1 = @constinferred svd_trunc(t; trunc) + U1, S1, Vᴴ1, ϵ1 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ1' ≈ U1 * S1 - @test isisometry(U1) - @test isisometry(Vᴴ1; side = :right) + @test isisometric(U1) + @test isisometric(Vᴴ1; side = :right) + @test norm(t - U1 * S1 * Vᴴ1) ≈ ϵ1 atol = eps(real(T))^(4 / 5) @test dim(domain(S1)) <= trunc.howmany λ = minimum(minimum, values(LinearAlgebra.diag(S1))) trunc = trunctol(; atol = λ - 10eps(λ)) - U2, S2, Vᴴ2 = @constinferred svd_trunc(t; trunc) + U2, S2, Vᴴ2, ϵ2 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ2' ≈ U2 * S2 - @test isisometry(U2) - @test isisometry(Vᴴ2; side = :right) + @test isisometric(U2) + @test isisometric(Vᴴ2; side = :right) + @test norm(t - U2 * S2 * Vᴴ2) ≈ ϵ2 atol = eps(real(T))^(4 / 5) @test minimum(minimum, values(LinearAlgebra.diag(S1))) >= λ @test U2 ≈ U1 @test S2 ≈ S1 @test Vᴴ2 ≈ Vᴴ1 + @test ϵ1 ≈ ϵ2 trunc = truncspace(space(S2, 1)) - U3, S3, Vᴴ3 = @constinferred svd_trunc(t; trunc) + U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ3' ≈ U3 * S3 - @test isisometry(U3) - @test isisometry(Vᴴ3; side = :right) + @test isisometric(U3) + @test isisometric(Vᴴ3; side = :right) + @test norm(t - U3 * S3 * Vᴴ3) ≈ ϵ3 atol = eps(real(T))^(4 / 5) @test space(S3, 1) ≾ space(S2, 1) - trunc = truncerror(; atol = 0.5) - U4, S4, Vᴴ4 = @constinferred svd_trunc(t; trunc) + trunc = truncerror(; atol = ϵ2) + U4, S4, Vᴴ4, ϵ4 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ4' ≈ U4 * S4 - @test isisometry(U4) - @test isisometry(Vᴴ4; side = :right) - @test norm(t - U4 * S4 * Vᴴ4) <= 0.5 + @test isisometric(U4) + @test isisometric(Vᴴ4; side = :right) + @test norm(t - U4 * S4 * Vᴴ4) ≈ ϵ4 atol = eps(real(T))^(4 / 5) + @test ϵ4 ≤ ϵ2 trunc = truncrank(dim(domain(S)) ÷ 2) & trunctol(; atol = λ - 10eps(λ)) - U5, S5, Vᴴ5 = @constinferred svd_trunc(t; trunc) + U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ5' ≈ U5 * S5 - @test isisometry(U5) - @test isisometry(Vᴴ5; side = :right) + @test isisometric(U5) + @test isisometric(Vᴴ5; side = :right) + @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(minimum, values(LinearAlgebra.diag(S5))) >= λ @test dim(domain(S5)) ≤ dim(domain(S)) ÷ 2 end @@ -304,7 +323,7 @@ for V in spacelist t2 = (t + t') D, V = eigen(t2) - @test isisometry(V) + @test isisometric(V) D̃, Ṽ = @constinferred eigh_full(t2) @test D ≈ D̃ @test V ≈ Ṽ @@ -370,5 +389,62 @@ for V in spacelist @test cond(t) ≈ λmax / λmin end end + + @testset "Hermitian projections" begin + for T in eltypes, + t in ( + rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1), + ) + normalize!(t) + noisefactor = eps(real(T))^(3 / 4) + + th = (t + t') / 2 + ta = (t - t') / 2 + tc = copy(t) + + th′ = @constinferred project_hermitian(t) + @test ishermitian(th′) + @test th′ ≈ th + @test t == tc + th_approx = th + noisefactor * ta + @test !ishermitian(th_approx) || (T <: Real && t isa DiagonalTensorMap) + @test ishermitian(th_approx; atol = 10 * noisefactor) + + ta′ = project_antihermitian(t) + @test isantihermitian(ta′) + @test ta′ ≈ ta + @test t == tc + ta_approx = ta + noisefactor * th + @test !isantihermitian(ta_approx) + @test isantihermitian(ta_approx; atol = 10 * noisefactor) || (T <: Real && t isa DiagonalTensorMap) + end + end + + @testset "Isometric projections" begin + for T in eltypes, + t in ( + randn(T, W, W), randn(T, W, W)', + randn(T, W, V1), randn(T, V1, W)', + ) + t2 = project_isometric(t) + @test isisometric(t2) + t3 = project_isometric(t2) + @test t3 ≈ t2 # stability of the projection + @test t2 * (t2' * t) ≈ t + + tc = similar(t) + t3 = @constinferred project_isometric!(copy!(tc, t), t2) + @test t3 === t2 + @test isisometric(t2) + + # test that t2 is closer to A then any other isometry + for k in 1:10 + δt = randn!(similar(t)) + t3 = project_isometric(t + δt / 100) + @test norm(t - t3) > norm(t - t2) + end + end + end end end diff --git a/test/tensors/tensors.jl b/test/tensors/tensors.jl index 51c326cf9..34934eb34 100644 --- a/test/tensors/tensors.jl +++ b/test/tensors/tensors.jl @@ -364,7 +364,7 @@ for V in spacelist for T in (Float64, ComplexF64) t1 = randisometry(T, W1, W2) t2 = randisometry(T, W2 ← W2) - @test isisometry(t1) + @test isisometric(t1) @test isunitary(t2) P = t1 * t1' @test P * P ≈ P