From 39fb00af3ccc9bebb1c8d1e623580fbb179a8fd6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 14:48:22 -0500 Subject: [PATCH 1/6] update pullback tolerances --- src/common/defaults.jl | 21 ++++++++++++++++----- src/pullbacks/eig.jl | 20 ++++++++------------ src/pullbacks/eigh.jl | 24 ++++++++++-------------- src/pullbacks/lq.jl | 24 ++++++++++++------------ src/pullbacks/qr.jl | 23 ++++++++++++----------- src/pullbacks/svd.jl | 28 ++++++++++++---------------- 6 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index c275ee87..9aca5391 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -10,15 +10,26 @@ quantity needs to be computed. defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3) """ - default_pullback_gaugetol(a) + default_pullback_gauge_atol(A) Default tolerance for deciding to warn if incoming adjoints of a pullback rule has components that are not gauge-invariant. """ -function default_pullback_gaugetol(a) - n = norm(a, Inf) - return eps(eltype(n))^(3 / 4) * max(n, one(n)) -end +default_pullback_gauge_atol(A) = eps(real(eltype(A)))^(3 / 4) + +""" + default_pullback_degeneracy_atol(A) + +Default tolerance for deciding which singular values should be considered as degenerate. +""" +default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4) + +""" + default_pullback_rank_atol(A) + +Default tolerance for deciding what singular values should be considered equal to 0. +""" +default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) """ default_hermitian_tol(A) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index c18a3158..f3b5991d 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,9 +1,8 @@ """ eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output @@ -22,9 +21,8 @@ not small compared to `gauge_atol`. """ function eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) # Basic size checks and determination @@ -84,9 +82,8 @@ end """ eig_trunc_pullback!( ΔA::AbstractMatrix, ΔDV, A, DV; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the @@ -106,9 +103,8 @@ not small compared to `gauge_atol`. """ function eig_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) # Basic size checks and determination diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index b15f912e..15d3cdb2 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,9 +1,8 @@ """ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; - tol = default_pullback_gaugetol(DV[1]), - degeneracy_atol = tol, - gauge_atol = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the @@ -22,9 +21,8 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh """ function eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) # Basic size checks and determination @@ -49,7 +47,7 @@ function eigh_pullback!( Δgauge < gauge_atol || @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - aVᴴΔV .*= inv_safe.(D' .- D, tol) + aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) @@ -74,9 +72,8 @@ end """ eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol=default_pullback_gaugetol(DV[1]), - degeneracy_atol=tol, - gauge_atol=tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`, @@ -96,9 +93,8 @@ not small compared to `gauge_atol`. """ function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; - tol::Real = default_pullback_gaugetol(DV[1]), - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(DV[1]) ) # Basic size checks and determination @@ -119,7 +115,7 @@ function eigh_trunc_pullback!( Δgauge < gauge_atol || @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - aVᴴΔV .*= inv_safe.(D' .- D, tol) + aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 05e23c38..cef82af1 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -1,9 +1,8 @@ """ lq_pullback!( ΔA, A, LQ, ΔLQ; - tol::Real = default_pullback_gaugetol(LQ[1]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(LQ[1]), + gauge_atol::Real = default_pullback_gauge_atol(LQ[1]) ) Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and @@ -18,9 +17,8 @@ or rows exceed `gauge_atol`, a warning will be printed. """ function lq_pullback!( ΔA::AbstractMatrix, A, LQ, ΔLQ; - tol::Real = default_pullback_gaugetol(LQ[1]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(LQ[1]), + gauge_atol::Real = default_pullback_gauge_atol(LQ[1]) ) # process L, Q = LQ @@ -28,7 +26,7 @@ function lq_pullback!( n = size(Q, 2) minmn = min(m, n) Ld = diagview(L) - p = findlast(>=(rank_atol) ∘ abs, Ld) + p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0 ΔL, ΔQ = ΔLQ @@ -72,7 +70,7 @@ function lq_pullback!( # Q2' * ΔQ2 as a gauge dependent quantity. ΔQ2Q1ᴴ = ΔQ2 * Q1' Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end @@ -105,7 +103,10 @@ function lq_pullback!( end """ - lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ) + lq_null_pullback!( + ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; + gauge_atol::Real = default_pullback_gauge_atol(A) + ) Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis `Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`. @@ -114,13 +115,12 @@ See also [`lq_pullback!`](@ref). """ function lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; - tol::Real = default_pullback_gaugetol(A), - gauge_atol::Real = tol + gauge_atol::Real = default_pullback_gauge_atol(A) ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') Δgauge = norm(aNᴴΔN) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here? X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ') diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index bc49d6af..ae98530e 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -2,8 +2,8 @@ qr_pullback!( ΔA, A, QR, ΔQR; tol::Real = default_pullback_gaugetol(QR[2]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(QR[2]), + gauge_atol::Real = default_pullback_gauge_atol(QR[2]) ) Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and @@ -18,9 +18,8 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i """ function qr_pullback!( ΔA::AbstractMatrix, A, QR, ΔQR; - tol::Real = default_pullback_gaugetol(QR[2]), - rank_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(QR[2]), + gauge_atol::Real = default_pullback_gauge_atol(QR[2]) ) # process Q, R = QR @@ -28,7 +27,7 @@ function qr_pullback!( n = size(R, 2) minmn = min(m, n) Rd = diagview(R) - p = findlast(>=(rank_atol) ∘ abs, Rd) + p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0 ΔQ, ΔR = ΔQR @@ -71,7 +70,7 @@ function qr_pullback!( # Q2' * ΔQ2 as a gauge dependent quantity. Q1dΔQ2 = Q1' * ΔQ2 Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end @@ -104,7 +103,10 @@ function qr_pullback!( end """ - qr_null_pullback(ΔA, A, N, ΔN) + qr_null_pullback!( + ΔA::AbstractMatrix, A, N, ΔN; + gauge_atol::Real = default_pullback_gauge_atol(A) + ) Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis `N` and its cotangent `ΔN` of `qr_null(A)`. @@ -113,13 +115,12 @@ See also [`qr_pullback!`](@ref). """ function qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; - tol::Real = default_pullback_gaugetol(A), - gauge_atol::Real = tol + gauge_atol::Real = default_pullback_gauge_atol(A) ) if !iszerotangent(ΔN) && size(N, 2) > 0 aNᴴΔN = project_antihermitian!(N' * ΔN) Δgauge = norm(aNᴴΔN) - Δgauge < tol || + Δgauge < gauge_atol || @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" Q, R = qr_compact(A; positive = true) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index effceaa3..7b0979da 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,10 +1,9 @@ """ svd_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; - tol::Real=default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) ) Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or @@ -23,10 +22,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); - tol::Real = default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) ) # Extract the SVD components @@ -106,10 +104,9 @@ end """ svd_trunc_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ; - tol::Real=default_pullback_gaugetol(S), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) ) Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the @@ -128,10 +125,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; - tol::Real = default_pullback_gaugetol(USVᴴ[2]), - rank_atol::Real = tol, - degeneracy_atol::Real = tol, - gauge_atol::Real = tol + rank_atol::Real = 0, + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) ) # Extract the SVD components From fc6836996ae7a60bfc28d3c236862d6be7f01a1b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 16:09:02 -0500 Subject: [PATCH 2/6] update docstrings --- src/common/defaults.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index 9aca5391..25935507 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -20,14 +20,14 @@ default_pullback_gauge_atol(A) = eps(real(eltype(A)))^(3 / 4) """ default_pullback_degeneracy_atol(A) -Default tolerance for deciding which singular values should be considered as degenerate. +Default tolerance for deciding when values should be considered as degenerate. """ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4) """ default_pullback_rank_atol(A) -Default tolerance for deciding what singular values should be considered equal to 0. +Default tolerance for deciding what values should be considered equal to 0. """ default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) From 55c4e18147560dc9021714f72432f6891dbea0df Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 17:37:16 -0500 Subject: [PATCH 3/6] make gauge tolerances relative --- src/common/defaults.jl | 5 +++-- src/pullbacks/eig.jl | 8 ++++---- src/pullbacks/eigh.jl | 8 ++++---- src/pullbacks/lq.jl | 8 ++++---- src/pullbacks/qr.jl | 8 ++++---- src/pullbacks/svd.jl | 8 ++++---- 6 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index 25935507..740bf95a 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -10,12 +10,13 @@ quantity needs to be computed. defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3) """ - default_pullback_gauge_atol(A) + default_pullback_gauge_atol(ΔA...) Default tolerance for deciding to warn if incoming adjoints of a pullback rule has components that are not gauge-invariant. """ -default_pullback_gauge_atol(A) = eps(real(eltype(A)))^(3 / 4) +default_pullback_gauge_atol(A) = eps(norm(A, Inf))^(3 / 4) +default_pullback_gauge_atol(A, As...) = maximum(default_pullback_gauge_atol, (A, As...)) """ default_pullback_degeneracy_atol(A) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index f3b5991d..3115b3d5 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -2,7 +2,7 @@ eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output @@ -22,7 +22,7 @@ not small compared to `gauge_atol`. function eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination @@ -83,7 +83,7 @@ end eig_trunc_pullback!( ΔA::AbstractMatrix, ΔDV, A, DV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the @@ -104,7 +104,7 @@ not small compared to `gauge_atol`. function eig_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 15d3cdb2..e7c93c75 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -2,7 +2,7 @@ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the @@ -22,7 +22,7 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh function eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon(); degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination @@ -73,7 +73,7 @@ end eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`, @@ -94,7 +94,7 @@ not small compared to `gauge_atol`. function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(DV[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) # Basic size checks and determination diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index cef82af1..d2cfa290 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -2,7 +2,7 @@ lq_pullback!( ΔA, A, LQ, ΔLQ; rank_atol::Real = default_pullback_rank_atol(LQ[1]), - gauge_atol::Real = default_pullback_gauge_atol(LQ[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2]) ) Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and @@ -18,7 +18,7 @@ or rows exceed `gauge_atol`, a warning will be printed. function lq_pullback!( ΔA::AbstractMatrix, A, LQ, ΔLQ; rank_atol::Real = default_pullback_rank_atol(LQ[1]), - gauge_atol::Real = default_pullback_gauge_atol(LQ[1]) + gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2]) ) # process L, Q = LQ @@ -105,7 +105,7 @@ end """ lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; - gauge_atol::Real = default_pullback_gauge_atol(A) + gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) ) Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis @@ -115,7 +115,7 @@ See also [`lq_pullback!`](@ref). """ function lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; - gauge_atol::Real = default_pullback_gauge_atol(A) + gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index ae98530e..10882bdf 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -3,7 +3,7 @@ ΔA, A, QR, ΔQR; tol::Real = default_pullback_gaugetol(QR[2]), rank_atol::Real = default_pullback_rank_atol(QR[2]), - gauge_atol::Real = default_pullback_gauge_atol(QR[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1]) ) Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and @@ -19,7 +19,7 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i function qr_pullback!( ΔA::AbstractMatrix, A, QR, ΔQR; rank_atol::Real = default_pullback_rank_atol(QR[2]), - gauge_atol::Real = default_pullback_gauge_atol(QR[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1]) ) # process Q, R = QR @@ -105,7 +105,7 @@ end """ qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; - gauge_atol::Real = default_pullback_gauge_atol(A) + gauge_atol::Real = default_pullback_gauge_atol(ΔN) ) Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis @@ -115,7 +115,7 @@ See also [`qr_pullback!`](@ref). """ function qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; - gauge_atol::Real = default_pullback_gauge_atol(A) + gauge_atol::Real = default_pullback_gauge_atol(ΔN) ) if !iszerotangent(ΔN) && size(N, 2) > 0 aNᴴΔN = project_antihermitian!(N' * ΔN) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 7b0979da..c0353a3a 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -3,7 +3,7 @@ ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or @@ -24,7 +24,7 @@ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components @@ -106,7 +106,7 @@ end ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the @@ -127,7 +127,7 @@ function svd_trunc_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ; rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - gauge_atol::Real = default_pullback_gauge_atol(USVᴴ[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components From 5bff5bce8193d98750fdf80be031b3f635bb34b6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 17:37:23 -0500 Subject: [PATCH 4/6] make hermitian tolerances relative --- src/common/defaults.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index 740bf95a..ecbef3cf 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -37,7 +37,4 @@ default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) Default tolerance for deciding to warn if the provided `A` is not hermitian. """ -function default_hermitian_tol(A) - n = norm(A, Inf) - return eps(eltype(n))^(3 / 4) * max(n, one(n)) -end +default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4) From b263ac4a56fd26016639c76f7c5ff66989a5bcf1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Nov 2025 10:49:01 -0500 Subject: [PATCH 5/6] deal with zerotangents --- src/common/defaults.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index ecbef3cf..aaeb7d7f 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -15,7 +15,7 @@ defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3) Default tolerance for deciding to warn if incoming adjoints of a pullback rule has components that are not gauge-invariant. """ -default_pullback_gauge_atol(A) = eps(norm(A, Inf))^(3 / 4) +default_pullback_gauge_atol(A) = iszerotangent(A) ? 0 : eps(norm(A, Inf))^(3 / 4) default_pullback_gauge_atol(A, As...) = maximum(default_pullback_gauge_atol, (A, As...)) """ From 6b16b48e9624aa49b808b274dec954ba2abc83e5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Nov 2025 14:23:20 -0500 Subject: [PATCH 6/6] attempt to contain type instability --- src/common/defaults.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index aaeb7d7f..332807da 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -16,7 +16,10 @@ Default tolerance for deciding to warn if incoming adjoints of a pullback rule has components that are not gauge-invariant. """ default_pullback_gauge_atol(A) = iszerotangent(A) ? 0 : eps(norm(A, Inf))^(3 / 4) -default_pullback_gauge_atol(A, As...) = maximum(default_pullback_gauge_atol, (A, As...)) +function default_pullback_gauge_atol(A, As...) + As′ = filter(!iszerotangent, (A, As...)) + return isempty(As′) ? 0 : eps(norm(As′, Inf))^(3 / 4) +end """ default_pullback_degeneracy_atol(A)