diff --git a/src/common/defaults.jl b/src/common/defaults.jl index c275ee87..332807da 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -10,22 +10,34 @@ quantity needs to be computed. defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3) """ - default_pullback_gaugetol(a) + default_pullback_gauge_atol(ΔA...) Default tolerance for deciding to warn if incoming adjoints of a pullback rule has components that are not gauge-invariant. """ -function default_pullback_gaugetol(a) - n = norm(a, Inf) - return eps(eltype(n))^(3 / 4) * max(n, one(n)) +default_pullback_gauge_atol(A) = iszerotangent(A) ? 0 : eps(norm(A, Inf))^(3 / 4) +function default_pullback_gauge_atol(A, As...) + As′ = filter(!iszerotangent, (A, As...)) + return isempty(As′) ? 0 : eps(norm(As′, Inf))^(3 / 4) end +""" + default_pullback_degeneracy_atol(A) + +Default tolerance for deciding when values should be considered as degenerate. +""" +default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4) + +""" + default_pullback_rank_atol(A) + +Default tolerance for deciding what values should be considered equal to 0. +""" +default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) + """ default_hermitian_tol(A) Default tolerance for deciding to warn if the provided `A` is not hermitian. """ -function default_hermitian_tol(A) - n = norm(A, Inf) - return eps(eltype(n))^(3 / 4) * max(n, one(n)) -end +default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index c18a3158..3115b3d5 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,9 +1,8 @@ """ eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output @@ -22,9 +21,8 @@ not small compared to `gauge_atol`. """ function eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination @@ -84,9 +82,8 @@ end """ eig_trunc_pullback!( ΔA::AbstractMatrix, ΔDV, A, DV; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the @@ -106,9 +103,8 @@ not small compared to `gauge_atol`. """ function eig_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index b15f912e..e7c93c75 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,9 +1,8 @@ """ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the @@ -22,9 +21,8 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh """ function eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination @@ -49,7 +47,7 @@ function eigh_pullback!( Δgauge < gauge_atol || @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - aVᴴΔV .*= inv_safe.(D' .- D, tol) + aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) @@ -74,9 +72,8 @@ end """ eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol=default_pullback_gaugetol(DV[1]), - degeneracy_atol=tol, - gauge_atol=tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`, @@ -96,9 +93,8 @@ not small compared to `gauge_atol`. """ function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination @@ -119,7 +115,7 @@ function eigh_trunc_pullback!( Δgauge < gauge_atol || @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - aVᴴΔV .*= inv_safe.(D' .- D, tol) + aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 05e23c38..d2cfa290 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -1,9 +1,8 @@ """ lq_pullback!( ΔA, A, LQ, ΔLQ; - tol::Real = default_pullback_gaugetol(LQ[1]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(LQ[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2]) ) Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and @@ -18,9 +17,8 @@ or rows exceed `gauge_atol`, a warning will be printed. """ function lq_pullback!( ΔA::AbstractMatrix, A, LQ, ΔLQ; - tol::Real = default_pullback_gaugetol(LQ[1]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(LQ[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2]) ) # process L, Q = LQ @@ -28,7 +26,7 @@ function lq_pullback!( n = size(Q, 2) minmn = min(m, n) Ld = diagview(L) - p = findlast(>=(rank_atol) ∘ abs, Ld) + p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0 ΔL, ΔQ = ΔLQ @@ -72,7 +70,7 @@ function lq_pullback!( # Q2' * ΔQ2 as a gauge dependent quantity. ΔQ2Q1ᴴ = ΔQ2 * Q1' Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end @@ -105,7 +103,10 @@ function lq_pullback!( end """ - lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ) + lq_null_pullback!( + ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; + gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) + ) Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis `Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`. @@ -114,13 +115,12 @@ See also [`lq_pullback!`](@ref). """ function lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; - tol::Real = default_pullback_gaugetol(A), - gauge_atol::Real = tol + gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') Δgauge = norm(aNᴴΔN) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here? X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ') diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index bc49d6af..10882bdf 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -2,8 +2,8 @@ qr_pullback!( ΔA, A, QR, ΔQR; tol::Real = default_pullback_gaugetol(QR[2]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(QR[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1]) ) Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and @@ -18,9 +18,8 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i """ function qr_pullback!( ΔA::AbstractMatrix, A, QR, ΔQR; - tol::Real = default_pullback_gaugetol(QR[2]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(QR[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1]) ) # process Q, R = QR @@ -28,7 +27,7 @@ function qr_pullback!( n = size(R, 2) minmn = min(m, n) Rd = diagview(R) - p = findlast(>=(rank_atol) ∘ abs, Rd) + p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0 ΔQ, ΔR = ΔQR @@ -71,7 +70,7 @@ function qr_pullback!( # Q2' * ΔQ2 as a gauge dependent quantity. Q1dΔQ2 = Q1' * ΔQ2 Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end @@ -104,7 +103,10 @@ function qr_pullback!( end """ - qr_null_pullback(ΔA, A, N, ΔN) + qr_null_pullback!( + ΔA::AbstractMatrix, A, N, ΔN; + gauge_atol::Real = default_pullback_gauge_atol(ΔN) + ) Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis `N` and its cotangent `ΔN` of `qr_null(A)`. @@ -113,13 +115,12 @@ See also [`qr_pullback!`](@ref). """ function qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; - tol::Real = default_pullback_gaugetol(A), - gauge_atol::Real = tol + gauge_atol::Real = default_pullback_gauge_atol(ΔN) ) if !iszerotangent(ΔN) && size(N, 2) > 0 aNᴴΔN = project_antihermitian!(N' * ΔN) Δgauge = norm(aNᴴΔN) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" Q, R = qr_compact(A; positive = true) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index effceaa3..c0353a3a 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,10 +1,9 @@ """ svd_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; - tol::Real=default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or @@ -23,10 +22,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); - tol::Real = default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components @@ -106,10 +104,9 @@ end """ svd_trunc_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ; - tol::Real=default_pullback_gaugetol(S), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the @@ -128,10 +125,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; - tol::Real = default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = 0, + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components