@@ -36,7 +36,7 @@ function svd_pullback!(
3636 minmn = min (m, n)
3737 S = diagview (Smat)
3838 length (S) == minmn || throw (DimensionMismatch ())
39- r = searchsortedlast (S, tol ; rev = true ) # rank
39+ r = searchsortedlast (S, rank_atol ; rev = true ) # rank
4040 Ur = view (U, :, 1 : r)
4141 Vᴴr = view (Vᴴ, 1 : r, :)
4242 Sr = view (S, 1 : r)
@@ -53,7 +53,8 @@ function svd_pullback!(
5353 length (indU) == pU || throw (DimensionMismatch ())
5454 UΔUp = view (UΔU, :, indU)
5555 mul! (UΔUp, Ur' , ΔU)
56- mul! (ΔU, Ur, UΔUp, - 1 , 1 )
56+ # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
57+ ΔU = mul! (copy (ΔU), Ur, UΔUp, - 1 , 1 )
5758 end
5859 if ! iszerotangent (ΔVᴴ)
5960 n == size (ΔVᴴ, 2 ) || throw (DimensionMismatch ())
@@ -63,7 +64,8 @@ function svd_pullback!(
6364 length (indV) == pV || throw (DimensionMismatch ())
6465 VΔVp = view (VΔV, :, indV)
6566 mul! (VΔVp, Vᴴr, ΔVᴴ' )
66- mul! (ΔVᴴ, VΔVp' , Vᴴr, - 1 , 1 )
67+ # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
68+ ΔVᴴ = mul! (copy (ΔVᴴ), VΔVp' , Vᴴr, - 1 , 1 )
6769 end
6870
6971 # Project onto antihermitian part; hermitian part outside of Grassmann tangent space
@@ -152,7 +154,8 @@ function svd_trunc_pullback!(
152154 if ! iszerotangent (ΔVᴴ)
153155 (p, n) == size (ΔVᴴ) || throw (DimensionMismatch ())
154156 mul! (VΔV, Vᴴ, ΔVᴴ' )
155- mul! (ΔVᴴ, VΔV' , Vᴴ, - 1 , 1 )
157+ # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
158+ ΔVᴴ = mul! (copy (ΔVᴴ), VΔV' , Vᴴ, - 1 , 1 )
156159 end
157160
158161 # Project onto antihermitian part; hermitian part outside of Grassmann tangent space
0 commit comments