Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,34 @@ quantity needs to be computed.
defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3)

"""
default_pullback_gaugetol(a)
default_pullback_gauge_atol(ΔA...)

Default tolerance for deciding to warn if incoming adjoints of a pullback rule
has components that are not gauge-invariant.
"""
function default_pullback_gaugetol(a)
n = norm(a, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
default_pullback_gauge_atol(A) = iszerotangent(A) ? 0 : eps(norm(A, Inf))^(3 / 4)
function default_pullback_gauge_atol(A, As...)
As′ = filter(!iszerotangent, (A, As...))
return isempty(As′) ? 0 : eps(norm(As′, Inf))^(3 / 4)
end

"""
default_pullback_degeneracy_atol(A)

Default tolerance for deciding when values should be considered as degenerate.
"""
default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)

"""
default_pullback_rank_atol(A)

Default tolerance for deciding what values should be considered equal to 0.
"""
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)

"""
default_hermitian_tol(A)

Default tolerance for deciding to warn if the provided `A` is not hermitian.
"""
function default_hermitian_tol(A)
n = norm(A, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end
default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4)
20 changes: 8 additions & 12 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output
Expand All @@ -22,9 +21,8 @@ not small compared to `gauge_atol`.
"""
function eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand Down Expand Up @@ -84,9 +82,8 @@ end
"""
eig_trunc_pullback!(
ΔA::AbstractMatrix, ΔDV, A, DV;
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the
Expand All @@ -106,9 +103,8 @@ not small compared to `gauge_atol`.
"""
function eig_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand Down
24 changes: 10 additions & 14 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
tol = default_pullback_gaugetol(DV[1]),
degeneracy_atol = tol,
gauge_atol = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the
Expand All @@ -22,9 +21,8 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh
"""
function eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand All @@ -49,7 +47,7 @@ function eigh_pullback!(
Δgauge < gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

aVᴴΔV .*= inv_safe.(D' .- D, tol)
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand All @@ -74,9 +72,8 @@ end
"""
eigh_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol=default_pullback_gaugetol(DV[1]),
degeneracy_atol=tol,
gauge_atol=tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`,
Expand All @@ -96,9 +93,8 @@ not small compared to `gauge_atol`.
"""
function eigh_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
tol::Real = default_pullback_gaugetol(DV[1]),
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)

# Basic size checks and determination
Expand All @@ -119,7 +115,7 @@ function eigh_trunc_pullback!(
Δgauge < gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

aVᴴΔV .*= inv_safe.(D' .- D, tol)
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand Down
24 changes: 12 additions & 12 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)

Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and
Expand All @@ -18,17 +17,16 @@ or rows exceed `gauge_atol`, a warning will be printed.
"""
function lq_pullback!(
ΔA::AbstractMatrix, A, LQ, ΔLQ;
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)
# process
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
minmn = min(m, n)
Ld = diagview(L)
p = findlast(>=(rank_atol) ∘ abs, Ld)
p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0

ΔL, ΔQ = ΔLQ

Expand Down Expand Up @@ -72,7 +70,7 @@ function lq_pullback!(
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
Expand Down Expand Up @@ -105,7 +103,10 @@ function lq_pullback!(
end

"""
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)
lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)

Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
Expand All @@ -114,13 +115,12 @@ See also [`lq_pullback!`](@ref).
"""
function lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
Expand Down
23 changes: 12 additions & 11 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
qr_pullback!(
ΔA, A, QR, ΔQR;
tol::Real = default_pullback_gaugetol(QR[2]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(QR[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
)

Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and
Expand All @@ -18,17 +18,16 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i
"""
function qr_pullback!(
ΔA::AbstractMatrix, A, QR, ΔQR;
tol::Real = default_pullback_gaugetol(QR[2]),
rank_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(QR[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
)
# process
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = findlast(>=(rank_atol) ∘ abs, Rd)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0

ΔQ, ΔR = ΔQR

Expand Down Expand Up @@ -71,7 +70,7 @@ function qr_pullback!(
# Q2' * ΔQ2 as a gauge dependent quantity.
Q1dΔQ2 = Q1' * ΔQ2
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
end
Expand Down Expand Up @@ -104,7 +103,10 @@ function qr_pullback!(
end

"""
qr_null_pullback(ΔA, A, N, ΔN)
qr_null_pullback!(
ΔA::AbstractMatrix, A, N, ΔN;
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
)

Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
`N` and its cotangent `ΔN` of `qr_null(A)`.
Expand All @@ -113,13 +115,12 @@ See also [`qr_pullback!`](@ref).
"""
function qr_null_pullback!(
ΔA::AbstractMatrix, A, N, ΔN;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
)
if !iszerotangent(ΔN) && size(N, 2) > 0
aNᴴΔN = project_antihermitian!(N' * ΔN)
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
Δgauge < gauge_atol ||
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Q, R = qr_compact(A; positive = true)
Expand Down
28 changes: 12 additions & 16 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
svd_pullback!(
ΔA, A, USVᴴ, ΔUSVᴴ, [ind];
tol::Real=default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
Expand All @@ -23,10 +22,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

# Extract the SVD components
Expand Down Expand Up @@ -106,10 +104,9 @@ end
"""
svd_trunc_pullback!(
ΔA, A, USVᴴ, ΔUSVᴴ;
tol::Real=default_pullback_gaugetol(S),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the
Expand All @@ -128,10 +125,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_trunc_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ;
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
rank_atol::Real = tol,
degeneracy_atol::Real = tol,
gauge_atol::Real = tol
rank_atol::Real = 0,
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

# Extract the SVD components
Expand Down