Skip to content

Commit 2231f49

Browse files
committed
Small fixes on top of #53 (#59)
* use `rank_atol` for determining SVD rank * avoid overwriting input tangents
1 parent 07fb93c commit 2231f49

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/pullbacks/svd.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function svd_pullback!(
3636
minmn = min(m, n)
3737
S = diagview(Smat)
3838
length(S) == minmn || throw(DimensionMismatch())
39-
r = searchsortedlast(S, tol; rev = true) # rank
39+
r = searchsortedlast(S, rank_atol; rev = true) # rank
4040
Ur = view(U, :, 1:r)
4141
Vᴴr = view(Vᴴ, 1:r, :)
4242
Sr = view(S, 1:r)
@@ -53,7 +53,8 @@ function svd_pullback!(
5353
length(indU) == pU || throw(DimensionMismatch())
5454
UΔUp = view(UΔU, :, indU)
5555
mul!(UΔUp, Ur', ΔU)
56-
mul!(ΔU, Ur, UΔUp, -1, 1)
56+
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
57+
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5758
end
5859
if !iszerotangent(ΔVᴴ)
5960
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
@@ -63,7 +64,8 @@ function svd_pullback!(
6364
length(indV) == pV || throw(DimensionMismatch())
6465
VΔVp = view(VΔV, :, indV)
6566
mul!(VΔVp, Vᴴr, ΔVᴴ')
66-
mul!(ΔVᴴ, VΔVp', Vᴴr, -1, 1)
67+
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
68+
ΔVᴴ = mul!(copy(ΔVᴴ), VΔVp', Vᴴr, -1, 1)
6769
end
6870

6971
# Project onto antihermitian part; hermitian part outside of Grassmann tangent space
@@ -152,7 +154,8 @@ function svd_trunc_pullback!(
152154
if !iszerotangent(ΔVᴴ)
153155
(p, n) == size(ΔVᴴ) || throw(DimensionMismatch())
154156
mul!(VΔV, Vᴴ, ΔVᴴ')
155-
mul!(ΔVᴴ, VΔV', Vᴴ, -1, 1)
157+
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
158+
ΔVᴴ = mul!(copy(ΔVᴴ), VΔV', Vᴴ, -1, 1)
156159
end
157160

158161
# Project onto antihermitian part; hermitian part outside of Grassmann tangent space

0 commit comments

Comments
 (0)