Skip to content

Commit 5d44ca4

Browse files
committed
Mooncake forward rules
1 parent 5fc5ce5 commit 5d44ca4

File tree

12 files changed

+483
-86
lines changed

12 files changed

+483
-86
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 224 additions & 51 deletions
Large diffs are not rendered by default.

src/MatrixAlgebraKit.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,11 @@ include("pullbacks/eigh.jl")
115115
include("pullbacks/svd.jl")
116116
include("pullbacks/polar.jl")
117117

118+
include("pushforwards/qr.jl")
119+
include("pushforwards/lq.jl")
120+
include("pushforwards/eig.jl")
121+
include("pushforwards/eigh.jl")
122+
include("pushforwards/polar.jl")
123+
include("pushforwards/svd.jl")
124+
118125
end

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg)
22+
#check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix

src/pullbacks/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ function eig_pullback!(
4646
Δgauge ≤ gauge_atol ||
4747
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
4848

49-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
49+
VᴴΔV ./= conj.(transpose(D) .- D)
50+
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))
5051

5152
if !iszerotangent(ΔDmat)
5253
ΔDvec = diagview(ΔDmat)

src/pushforwards/eig.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
    eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)

Fill the tangents `ΔDV = (ΔD, ΔV)` of the factors `DV = (D, V)` in place from the
input tangent `ΔA`, and return `ΔDV`.

With `X = inv(V) * ΔA * V`:

- the diagonal of `ΔD` is overwritten with the diagonal of `X`;
- unless `iszerotangent(ΔV)` holds, `ΔV` is overwritten with `V * K̇`, where
  `K̇[i, j] = X[i, j] / (D[j, j] - D[i, i])` for `i != j` and `K̇[i, i] = 0`.

No regularisation is applied: two coinciding diagonal entries of `D` produce a
division by zero in the corresponding off-diagonal entries of `K̇`.

`A` and `kwargs` are accepted but not used by this method.
"""
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    D, V = DV
    ΔD, ΔV = ΔDV
    # Solve against V instead of forming inv(V) explicitly.
    VinvΔAV = (V \ ΔA) * V
    diagview(ΔD) .= diagview(VinvΔAV)
    if !iszerotangent(ΔV)
        λ = diagview(D)
        # F[i, j] = 1 / (λ[j] - λ[i]); the diagonal is 1 / 0 and is reset to zero below.
        F = 1 ./ (transpose(λ) .- λ)
        fill!(diagview(F), zero(eltype(F)))
        K̇ = F .* VinvΔAV
        mul!(ΔV, V, K̇)
    end
    return ΔDV
end
14+
15+
"""
    eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...)

Placeholder with an empty implementation: no argument is read or modified and
`nothing` is returned.
"""
function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    return nothing
end

src/pushforwards/eigh.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
    eigh_pushforward!(dA, A, DV, dDV; kwargs...)

Fill the tangents `dDV = (dD, dV)` of the factors `DV = (D, V)` in place from the
input tangent `dA`, and return the tuple `(dD, dV)`.

With `X = (V \\ dA) * V`:

- the diagonal of `dD` is overwritten with the real part of the diagonal of `X`;
- unless `iszerotangent(dV)` holds, `dV` is overwritten with `V * (F .* X)`, where
  `F[i, j] = 1 / (D[j, j] - D[i, i])` off the diagonal and `F[i, i] = 0`.

`A` and `kwargs` are accepted but not used by this method.
"""
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
    D, V = DV
    dD, dV = dDV
    buffer = V \ dA
    VdAV = buffer * V
    diagview(dD) .= real.(diag(VdAV))
    # Nothing more to do when the eigenvector tangent is not requested.
    iszerotangent(dV) && return (dD, dV)
    gaps = transpose(diagview(D)) .- diagview(D)
    invgaps = one(eltype(gaps)) ./ gaps
    # The diagonal of `gaps` is zero, so its reciprocal is discarded explicitly.
    diagview(invgaps) .= zero(eltype(invgaps))
    VdAV .*= invgaps
    # `buffer` is reused as the destination of the product before copying into `dV`.
    copyto!(dV, mul!(buffer, V, VdAV))
    return (dD, dV)
end
18+
19+
"""
    eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...)

Placeholder with an empty implementation: no argument is read or modified and
`nothing` is returned.
"""
function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...)
    return nothing
end

src/pushforwards/lq.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
    lq_pushforward!(dA, A, LQ, dLQ; tol, rank_atol = tol, gauge_atol = tol)

Forward rule for an LQ factorisation, implemented by delegating to
[`qr_pushforward!`](@ref) on the adjoint problem: `dA` and `A` are wrapped in
`adjoint`, and the factor/tangent collections `LQ` and `dLQ` are reversed and
adjointed element-wise before the call. All three tolerances are forwarded.

`tol` defaults to `default_pullback_gauge_atol(LQ[1])`.

The return value is whatever `qr_pushforward!` returns for those arguments.

NOTE(review): that return value is built from the adjointed, reversed tangents and
is not `dLQ` itself; confirm callers rely only on the in-place update.
"""
function lq_pushforward!(
        dA, A, LQ, dLQ;
        tol::Real = default_pullback_gauge_atol(LQ[1]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    qr_factors = adjoint.(reverse(LQ))
    qr_tangents = adjoint.(reverse(dLQ))
    return qr_pushforward!(
        adjoint(dA), adjoint(A), qr_factors, qr_tangents;
        tol = tol, rank_atol = rank_atol, gauge_atol = gauge_atol
    )
end
4+
5+
"""
    lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol, rank_atol = tol, gauge_atol = tol)

Stub forward rule. Returns `nothing` when any dimension of `Nᴴ` is zero and
`false` otherwise; no argument is modified in either case.

`tol` defaults to `default_pullback_gauge_atol(Nᴴ)`. None of the tolerances are
used by the body.
"""
function lq_null_pushforward!(
        dA, A, Nᴴ, dNᴴ;
        tol::Real = default_pullback_gauge_atol(Nᴴ),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    # An empty block has nothing to differentiate.
    any(iszero, size(Nᴴ)) && return nothing
    return false
end

src/pushforwards/polar.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
    left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)

Fill the tangents `ΔWP = (ΔW, ΔP)` of the factors `WP = (W, P)` in place from the
input tangent `ΔA`, and return `(ΔW, ΔP)`.

With `X = W' * ΔA`, the method

1. solves the Sylvester equation `P * K̇ + K̇ * P = X - X'` for `K̇`,
2. forms `L̇ = (ΔA - W * X) / P`, i.e. `(I - W * W') * ΔA * inv(P)`,
3. writes `ΔW .= W * K̇ + L̇` and `ΔP .= X - K̇ * P`.

For `A = W * P` with `W' * W = I` and Hermitian `P`, these satisfy
`ΔW * P + W * ΔP = ΔA`.

`A` and `kwargs` are accepted but not used by this method.
"""
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
    W, P = WP
    ΔW, ΔP = ΔWP
    WᴴΔA = adjoint(W) * ΔA
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the negated right-hand side.
    K̇ = sylvester(P, P, -(WᴴΔA - adjoint(WᴴΔA)))
    # Component of the tangent orthogonal to the range of W. `ΔA - W * (Wᴴ * ΔA)`
    # avoids materialising the dense projector I - W * Wᴴ, and `/ P` avoids inv(P).
    L̇ = (ΔA - W * WᴴΔA) / P
    ΔW .= W * K̇ .+ L̇
    ΔP .= WᴴΔA .- K̇ * P
    return (ΔW, ΔP)
end
11+
12+
"""
    right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)

Fill the tangents `ΔPWᴴ = (ΔP, ΔWᴴ)` of the factors `PWᴴ = (P, Wᴴ)` in place from
the input tangent `ΔA`, and return `(ΔWᴴ, ΔP)`.

With `X = ΔA * Wᴴ'`, the method

1. solves the Sylvester equation `P * K̇ + K̇ * P = X - X'` for `K̇`,
2. forms `L̇ = P \\ (ΔA - X * Wᴴ)`, i.e. `inv(P) * ΔA * (I - Wᴴ' * Wᴴ)`,
3. writes `ΔWᴴ .= K̇ * Wᴴ + L̇` and `ΔP .= X - P * K̇`.

For `A = P * Wᴴ` with `Wᴴ * Wᴴ' = I` and Hermitian `P`, these satisfy
`ΔP * Wᴴ + P * ΔWᴴ = ΔA`.

`A` and `kwargs` are accepted but not used by this method.

NOTE(review): the returned tuple is ordered `(ΔWᴴ, ΔP)`, the reverse of the
`(ΔP, ΔWᴴ)` order of `ΔPWᴴ`. The order is kept as-is for compatibility; confirm
whether callers depend on it.
"""
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
    P, Wᴴ = PWᴴ
    ΔP, ΔWᴴ = ΔPWᴴ
    ΔAW = ΔA * adjoint(Wᴴ)
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the negated right-hand side.
    K̇ = sylvester(P, P, -(ΔAW - adjoint(ΔAW)))
    # Component of the tangent orthogonal to the row space of Wᴴ. `ΔA - (ΔA * W) * Wᴴ`
    # avoids materialising the dense projector I - W * Wᴴ, and `P \` avoids inv(P).
    L̇ = P \ (ΔA - ΔAW * Wᴴ)
    ΔWᴴ .= K̇ * Wᴴ .+ L̇
    ΔP .= ΔAW .- P * K̇
    return (ΔWᴴ, ΔP)
end

src/pushforwards/qr.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
    qr_pushforward!(dA, A, QR, dQR; tol, rank_atol = tol, gauge_atol = tol)

Fill the tangents `dQR = (dQ, dR)` of the factors `QR = (Q, R)` in place from the
input tangent `dA`, and return `(dQ, dR)`.

The numerical rank `p` is the index of the last diagonal entry of `R` whose
magnitude is at least `rank_atol` (`0` if there is none). The factors are then
partitioned as

- `Q1 = Q[:, 1:p]`, `Q2 = Q[:, p+1:min(m, n)]`, `Q3 = Q[:, min(m, n)+1:m]`,
- `R11 = R[1:p, 1:p]`, `R12 = R[1:p, p+1:n]`, `R22 = R[p+1:min(m, n), p+1:n]`,

with `(m, n) = size(A)`, and the matching blocks of `dQ` and `dR` are overwritten.

`tol` defaults to `default_pullback_gauge_atol(QR[2])` and only serves as the default
for `rank_atol`; `gauge_atol` is accepted but not used by the body.
"""
function qr_pushforward!(
        dA, A, QR, dQR;
        tol::Real = default_pullback_gauge_atol(QR[2]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    Q, R = QR
    dQ, dR = dQR
    m, n = size(A, 1), size(A, 2)
    minmn = min(m, n)
    # `findlast` returns `nothing` when no diagonal entry reaches `rank_atol`;
    # treat that as rank zero instead of failing on the arithmetic below.
    p = something(findlast(>=(rank_atol) ∘ abs, diagview(R)), 0)
    deficient = (p + 1):minmn   # columns of Q / rows of R beyond the numerical rank
    trailing = (p + 1):n        # columns of A and R beyond the numerical rank

    Q1 = view(Q, 1:m, 1:p) # full rank portion
    Q2 = view(Q, 1:m, deficient)
    R11 = view(R, 1:p, 1:p)
    R12 = view(R, 1:p, trailing)

    dA1 = view(dA, 1:m, 1:p)
    dA2 = view(dA, 1:m, trailing)

    dQ1 = view(dQ, 1:m, 1:p)
    dQ2 = view(dQ, 1:m, deficient)
    # Always a view: an empty column range yields an empty view, so no separate
    # (and differently typed) fallback array is needed when `dQ` has no extra columns.
    dQ3 = view(dQ, :, (minmn + 1):size(dQ, 2))
    dR11 = view(dR, 1:p, 1:p)
    dR12 = view(dR, 1:p, trailing)
    dR22 = view(dR, deficient, trailing)

    # fwd rule for Q1 and R11 -- for a non-rank-deficient QR, this is all we need
    invR11 = inv(R11)
    tmp = Q1' * dA1 * invR11
    Rtmp = tmp + tmp'
    diagview(Rtmp) ./= 2
    # keep only the upper triangle (with the halved diagonal)
    view(Rtmp, lowertriangularind(Rtmp)) .= zero(eltype(Rtmp))
    dR11 .= Rtmp * R11
    dQ1 .= (dA1 - Q1 * dR11) * invR11
    dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
    if size(Q2, 2) > 0
        dQ2 .= -Q1 * (Q1' * Q2)
        dQ2 .+= Q2 * (Q2' * dQ2)
    end
    if m - minmn > 0 && size(Q, 2) > minmn
        # only present for qr_full or rank-deficient qr_compact
        Q3 = view(Q, :, (minmn + 1):m)
        # NOTE(review): the tangent block is set to Q3 itself, as in the original
        # code; confirm this is the intended gauge choice.
        dQ3 .= Q3
    end
    if !isempty(dR22)
        _, r22 = qr_compact(dA2 - dQ1 * R12 - Q1 * dR12; positive = true)
        dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
    end
    return (dQ, dR)
end
58+
59+
"""
    qr_null_pushforward!(dA, A, N, dN; tol, rank_atol = tol, gauge_atol = tol)

Stub forward rule. Returns `nothing` when any dimension of `N` is zero and
`false` otherwise; no argument is modified in either case.

`tol` defaults to `default_pullback_gauge_atol(N)`. None of the tolerances are
used by the body.
"""
function qr_null_pushforward!(
        dA, A, N, dN;
        tol::Real = default_pullback_gauge_atol(N),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    # An empty block has nothing to differentiate.
    any(iszero, size(N)) && return nothing
    return false
end

src/pushforwards/svd.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
    svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol = default_pullback_rank_atol(A), kwargs...)

Fill the tangents `ΔUSVᴴ = (ΔU, ΔS, ΔVᴴ)` of the factors `USVᴴ = (U, Smat, Vᴴ)` in
place from the input tangent `ΔA`, and return `(ΔU, ΔS, ΔVᴴ)`.

Only the leading `r` singular triplets are updated, where `r` is the number of
diagonal entries of `Smat` found by `searchsortedlast(S, rank_atol; rev = true)`:
the first `r` columns of `ΔU`, the leading `r × r` diagonal of `ΔS` and the first
`r` rows of `ΔVᴴ`. Other entries of the tangents are left untouched.

With `X = U₁' * ΔA * V₁` (subscript `1` denoting the leading `r` vectors):

- `diag(ΔS₁) = real(diag(X))`;
- `K̇ = F .* (X + X') / 2 + G .* (X - X') / 2` and
  `Ṁ = F .* (X + X') / 2 - G .* (X - X') / 2`, with
  `F[i, j] = 1 / (S[j] - S[i])` (zero diagonal) and `G[i, j] = 1 / (S[j] + S[i])`;
- the components of the tangents orthogonal to `U₁` and `V₁` are obtained from a
  Sylvester equation built from the projected `ΔA` and the projected remainder
  `(I - U₁ * U₁') * A * (I - V₁ * V₁')`.

No regularisation is applied: coinciding singular values within the leading `r`
produce a division by zero in `F`.

# Throws
- `DimensionMismatch` if `size(ΔA) != (size(U, 1), size(Vᴴ, 2))`.
"""
function svd_pushforward!(
        ΔA, A, USVᴴ, ΔUSVᴴ;
        rank_atol = default_pullback_rank_atol(A), kwargs...
    )
    U, Smat, Vᴴ = USVᴴ
    m, n = size(U, 1), size(Vᴴ, 2)
    (m, n) == size(ΔA) || throw(
        DimensionMismatch(
            "size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"
        )
    )
    S = diagview(Smat)
    ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
    r = searchsortedlast(S, rank_atol; rev = true) # rank

    vΔU = view(ΔU, :, 1:r)
    vΔS = view(ΔS, 1:r, 1:r)
    vΔVᴴ = view(ΔVᴴ, 1:r, :)

    vU = view(U, :, 1:r)
    vS = view(S, 1:r)
    vSmat = view(Smat, 1:r, 1:r)
    vVᴴ = view(Vᴴ, 1:r, :)

    # compact region
    vV = adjoint(vVᴴ)
    UΔAV = vU' * ΔA * vV
    UΔAVdiag = diagview(UΔAV)
    diagview(vΔS) .= real.(UΔAVdiag)
    F = one(eltype(S)) ./ (transpose(vS) .- vS)
    G = one(eltype(S)) ./ (transpose(vS) .+ vS)
    # the diagonal of F is 1 / 0 and must not contribute
    diagview(F) .= zero(eltype(F))
    hUΔAV = F .* (UΔAV + UΔAV') ./ 2
    aUΔAV = G .* (UΔAV - UΔAV') ./ 2
    K̇ = hUΔAV + aUΔAV
    Ṁ = hUΔAV - aUΔAV

    # check gauge condition
    @assert isantihermitian(K̇)
    @assert isantihermitian(Ṁ)
    K̇diag = diagview(K̇)
    for i in 1:r
        @assert K̇diag[i] ≈ (im / 2) * imag(UΔAVdiag[i]) / vS[i]
    end

    ∂U = vU * K̇
    ∂V = vV * Ṁ

    # Components orthogonal to the leading singular vectors. A separate branch guarded
    # by `size(U, 2) > min(m, n) && size(Vᴴ, 1) > min(m, n)` used to precede this code;
    # that condition cannot hold when `U` has at most `m` columns and `Vᴴ` at most `n`
    # rows, and the branch referenced undefined variables, so only this path is kept.
    ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU')
    ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ)
    upper = ImUU * ΔA * vV
    lower = ImVV * ΔA' * vU
    rhs = vcat(upper, lower)

    Ã = ImUU * A * ImVV
    # Hermitian block matrix [0 Ã; Ã' 0]
    ÃÃ = similar(A, (m + n, m + n))
    fill!(ÃÃ, 0)
    view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
    view(ÃÃ, m .+ (1:n), 1:m) .= Ã'

    superLN = -sylvester(ÃÃ, vSmat, rhs)
    ∂U .+= view(superLN, 1:size(upper, 1), :)
    ∂V .+= view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :)

    copyto!(vΔU, ∂U)
    adjoint!(vΔVᴴ, ∂V)
    return (ΔU, ΔS, ΔVᴴ)
end
79+
80+
"""
    svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)

Placeholder with an empty implementation: no argument is modified and `nothing` is
returned. `rank_atol` defaults to `default_pullback_rank_atol(A)` but is not used.
"""
function svd_trunc_pushforward!(
        ΔA, A, USVᴴ, ΔUSVᴴ, ind;
        rank_atol = default_pullback_rank_atol(A), kwargs...
    )
    return nothing
end

0 commit comments

Comments
 (0)