|
| 1 | +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol = default_pullback_rank_atol(A), kwargs...) |
| 2 | + U, Smat, Vᴴ = USVᴴ |
| 3 | + m, n = size(U, 1), size(Vᴴ, 2) |
| 4 | + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) |
| 5 | + minmn = min(m, n) |
| 6 | + S = diagview(Smat) |
| 7 | + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ |
| 8 | + r = searchsortedlast(S, rank_atol; rev = true) # rank |
| 9 | + |
| 10 | + vΔU = view(ΔU, :, 1:r) |
| 11 | + vΔS = view(ΔS, 1:r, 1:r) |
| 12 | + vΔVᴴ = view(ΔVᴴ, 1:r, :) |
| 13 | + |
| 14 | + vU = view(U, :, 1:r) |
| 15 | + vS = view(S, 1:r) |
| 16 | + vSmat = view(Smat, 1:r, 1:r) |
| 17 | + vVᴴ = view(Vᴴ, 1:r, :) |
| 18 | + |
| 19 | + # compact region |
| 20 | + vV = adjoint(vVᴴ) |
| 21 | + UΔAV = vU' * ΔA * vV |
| 22 | + copyto!(diagview(vΔS), diag(real.(UΔAV))) |
| 23 | + F = one(eltype(S)) ./ (transpose(vS) .- vS) |
| 24 | + G = one(eltype(S)) ./ (transpose(vS) .+ vS) |
| 25 | + diagview(F) .= zero(eltype(F)) |
| 26 | + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 |
| 27 | + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 |
| 28 | + K̇ = hUΔAV + aUΔAV |
| 29 | + Ṁ = hUΔAV - aUΔAV |
| 30 | + |
| 31 | + # check gauge condition |
| 32 | + @assert isantihermitian(K̇) |
| 33 | + @assert isantihermitian(Ṁ) |
| 34 | + K̇diag = diagview(K̇) |
| 35 | + for i in 1:length(K̇diag) |
| 36 | + @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i] |
| 37 | + end |
| 38 | + |
| 39 | + ∂U = vU * K̇ |
| 40 | + ∂V = vV * Ṁ |
| 41 | + # full component |
| 42 | + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn |
| 43 | + Uperp = view(U, :, (minmn + 1):m) |
| 44 | + Vᴴperp = view(Vᴴ, (minmn + 1):n, :) |
| 45 | + |
| 46 | + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) |
| 47 | + |
| 48 | + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) |
| 49 | + fill!(UÃÃV, 0) |
| 50 | + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV |
| 51 | + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' |
| 52 | + rhs = vcat(adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U) |
| 53 | + superKM = -sylvester(UÃÃV, Smat, rhs) |
| 54 | + K̇perp = view(superKM, 1:size(aUAV, 2)) |
| 55 | + Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2))) |
| 56 | + ∂U .+= Uperp * K̇perp |
| 57 | + ∂V .+= Vperp * Ṁperp |
| 58 | + else |
| 59 | + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU') |
| 60 | + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ) |
| 61 | + upper = ImUU * ΔA * vV |
| 62 | + lower = ImVV * ΔA' * vU |
| 63 | + rhs = vcat(upper, lower) |
| 64 | + |
| 65 | + Ã = ImUU * A * ImVV |
| 66 | + ÃÃ = similar(A, (m + n, m + n)) |
| 67 | + fill!(ÃÃ, 0) |
| 68 | + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã |
| 69 | + view(ÃÃ, m .+ (1:n), 1:m) .= Ã' |
| 70 | + |
| 71 | + superLN = -sylvester(ÃÃ, vSmat, rhs) |
| 72 | + ∂U += view(superLN, 1:size(upper, 1), :) |
| 73 | + ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :) |
| 74 | + end |
| 75 | + copyto!(vΔU, ∂U) |
| 76 | + adjoint!(vΔVᴴ, ∂V) |
| 77 | + return (ΔU, ΔS, ΔVᴴ) |
| 78 | +end |
| 79 | + |
| 80 | +function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...) |
| 81 | + |
| 82 | +end |
0 commit comments