Skip to content

Commit 40d5d1a

Browse files
committed
Working truncs
1 parent 0db2345 commit 40d5d1a

File tree

4 files changed

+87
-180
lines changed

4 files changed

+87
-180
lines changed

Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6161

6262
[targets]
6363
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", "EnzymeTestUtils"]
64-
65-
[sources]
66-
Enzyme = {path="/Users/khyatt/.julia/dev/Enzyme"}

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 55 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ module MatrixAlgebraKitEnzymeExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input
5-
using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!
5+
using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!, truncate
66
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
77
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
8-
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
8+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
99
using MatrixAlgebraKit: svd_pullback!
1010
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1111
using Enzyme
@@ -187,7 +187,7 @@ end
187187

188188
function EnzymeRules.augmented_primal(
189189
config::EnzymeRules.RevConfigWidth{1},
190-
func::Const{typeof(svd_trunc!)},
190+
func::Const{typeof(svd_trunc_no_error!)},
191191
::Type{RT},
192192
A::Annotation,
193193
USVᴴ::Annotation,
@@ -198,16 +198,16 @@ function EnzymeRules.augmented_primal(
198198
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
199199
cache_USVᴴ = copy.(USVᴴ.val)
200200
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
201-
primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing
201+
primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing
202202
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
203203
dU, dS, dVᴴ = USVᴴ.dval
204204
# This creates new output shadow matrices, we do this slicing
205205
# to ensure they have the correct eltype and dimensions.
206206
# These new shadow matrices are "filled in" with the accumulated
207207
# results from earlier in reverse-mode AD after this function exits
208208
# and before `reverse` is called.
209-
dStrunc = Diagonal(diagview(dS)[ind])
210-
dUtrunc = dU[:, ind]
209+
dStrunc = Diagonal(diagview(dS)[ind])
210+
dUtrunc = dU[:, ind]
211211
dVᴴtrunc = dVᴴ[ind, :]
212212
(dUtrunc, dStrunc, dVᴴtrunc)
213213
else
@@ -218,154 +218,72 @@ function EnzymeRules.augmented_primal(
218218
end
219219
function EnzymeRules.reverse(
220220
config::EnzymeRules.RevConfigWidth{1},
221-
func::Const{typeof(svd_trunc!)},
221+
func::Const{typeof(svd_trunc_no_error!)},
222222
dret::Type{RT},
223223
cache,
224224
A::Annotation,
225225
USVᴴ::Annotation,
226226
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
227227
) where {RT}
228228
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
229-
U, S, Vᴴ = cache_USVᴴ
229+
U, S, Vᴴ = cache_USVᴴ
230230
dU, dS, dVᴴ = shadow_USVᴴ
231-
Aval = isnothing(cache_A) ? A.val : cache_A
231+
Aval = isnothing(cache_A) ? A.val : cache_A
232232
if !isa(A, Const) && !isa(USVᴴ, Const)
233233
svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
234234
end
235235
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
236236
return (nothing, nothing, nothing)
237237
end
238-
#=
239-
function EnzymeRules.augmented_primal(
240-
config::EnzymeRules.RevConfigWidth{1},
241-
func::Const{typeof(svd_trunc)},
242-
::Type{MixedDuplicated},
243-
A::Annotation,
244-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
245-
)
246-
# form cache if needed
247-
cache_A = copy(A.val)
248-
U, S, Vᴴ, ϵ = svd_trunc(A.val, USVᴴ.val, alg.val.alg)
249-
primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
250-
dU = zero(U)
251-
dS = zero(S)
252-
dVᴴ = zero(Vᴴ)
253-
dϵ = zero(ϵ)
254-
shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
255-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, (U, S, Vᴴ), (dU, dS, dVᴴ)))
256-
end
257-
function EnzymeRules.reverse(
258-
config::EnzymeRules.RevConfigWidth{1},
259-
func::Const{typeof(svd_trunc)},
260-
dret::Type{MixedDuplicated},
261-
cache,
262-
A::Annotation,
263-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
264-
)
265-
cache_A, cache_USVᴴ, shadow_USVᴴ = cache
266-
U, S, Vᴴ = cache_USVᴴ
267-
dU, dS, dVᴴ = shadow_USVᴴ
268-
Aval = isnothing(cache_A) ? A.val : cache_A
269-
if !isa(A, Const) && !isa(USVᴴ, Const)
270-
svd_trunc_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
271-
end
272-
return (nothing, nothing, nothing)
273-
end
274-
=#
275-
function EnzymeRules.augmented_primal(
276-
config::EnzymeRules.RevConfigWidth{1},
277-
func::Const{typeof(eigh_trunc!)},
278-
::Type{RT},
279-
A::Annotation,
280-
DV::Annotation{Tuple{TD, TV}},
281-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
282-
) where {RT, TD, TV}
283-
# form cache if needed
284-
cache_A = copy(A.val)
285-
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
286-
cache_DV = copy.(DV.val)
287-
DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
288-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
289-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
290-
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
291-
dD, dV = DV.dval
292-
dDtrunc = Diagonal(diagview(dD)[ind])
293-
dVtrunc = dV[:, ind]
294-
(dDtrunc, dVtrunc)
295-
else
296-
(nothing, nothing)
297-
end
298-
!isa(ϵ, Const) && make_zero.dval)
299-
shadow_ϵ = !isa(ϵ, Const) ? ϵ.dval : zero(T)
300-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., shadow_ϵ) : nothing
301-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
302-
end
303-
function EnzymeRules.reverse(
304-
config::EnzymeRules.RevConfigWidth{1},
305-
func::Const{typeof(eigh_trunc!)},
306-
::Type{RT},
307-
cache,
308-
A::Annotation,
309-
DV::Annotation{Tuple{TD, TV}},
310-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
311-
) where {RT, TD, TV}
312-
cache_A, cache_DV, cache_dDVtrunc, ind = cache
313-
Aval = cache_A
314-
D, V = cache_DV
315-
dD, dV = cache_dDVtrunc
316-
if !isa(A, Const) && !isa(DV, Const)
317-
MatrixAlgebraKit.eigh_pullback!(A.dval, Aval, (D, V), (dD, dV), ind)
318-
end
319-
!isa(DV, Const) && make_zero!(DV.dval)
320-
!isa(ϵ, Const) && make_zero!.dval)
321-
return (nothing, nothing, nothing, nothing)
322-
end
323238

324-
function EnzymeRules.augmented_primal(
325-
config::EnzymeRules.RevConfigWidth{1},
326-
func::Const{typeof(eig_trunc!)},
327-
::Type{RT},
328-
A::Annotation,
329-
DV::Annotation{Tuple{TD, TV}},
330-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
331-
) where {RT, TD, TV}
332-
# form cache if needed
333-
cache_A = copy(A.val)
334-
eig_full!(A.val, DV.val, alg.val.alg)
335-
cache_DV = copy.(DV.val)
336-
DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
337-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
338-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
339-
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
340-
dD, dV = DV.dval
341-
dDtrunc = Diagonal(diagview(dD)[ind])
342-
dVtrunc = dV[:, ind]
343-
(dDtrunc, dVtrunc)
344-
else
345-
(nothing, nothing)
239+
for (f, trunc_f, full_f, pb) in (
240+
(:eigh_trunc_no_error!, :eigh_trunc!, :eigh_full!, :eigh_pullback!),
241+
(:eig_trunc_no_error!, :eig_trunc!, :eig_full!, :eig_pullback!),
242+
)
243+
@eval function EnzymeRules.augmented_primal(
244+
config::EnzymeRules.RevConfigWidth{1},
245+
func::Const{typeof($f)},
246+
::Type{RT},
247+
A::Annotation,
248+
DV::Annotation{Tuple{TD, TV}},
249+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
250+
) where {RT, TD, TV}
251+
# form cache if needed
252+
cache_A = copy(A.val)
253+
$full_f(A.val, DV.val, alg.val.alg)
254+
cache_DV = copy.(DV.val)
255+
DV′, ind = truncate($trunc_f, DV.val, alg.val.trunc)
256+
primal = EnzymeRules.needs_primal(config) ? DV′ : nothing
257+
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
258+
dD, dV = DV.dval
259+
dDtrunc = Diagonal(diagview(dD)[ind])
260+
dVtrunc = dV[:, ind]
261+
(dDtrunc, dVtrunc)
262+
else
263+
(nothing, nothing)
264+
end
265+
shadow = EnzymeRules.needs_shadow(config) ? shadow_DV : nothing
266+
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
346267
end
347-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(T)) : nothing
348-
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV))
349-
end
350-
function EnzymeRules.reverse(
351-
config::EnzymeRules.RevConfigWidth{1},
352-
func::Const{typeof(eig_trunc!)},
353-
::Type{RT},
354-
cache,
355-
A::Annotation,
356-
DV::Annotation{Tuple{TD, TV}},
357-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
358-
) where {RT, TD, TV}
359-
cache_A, cache_DV, cache_dDVtrunc = cache
360-
D, V = cache_DV
361-
Aval = cache_A
362-
dD, dV = cache_dDVtrunc
363-
if !isa(A, Const) && !isa(DV, Const)
364-
eig_trunc_pullback!(A.dval, Aval, (D, V), (dD, dV))
268+
@eval function EnzymeRules.reverse(
269+
config::EnzymeRules.RevConfigWidth{1},
270+
func::Const{typeof($f)},
271+
::Type{RT},
272+
cache,
273+
A::Annotation,
274+
DV::Annotation{Tuple{TD, TV}},
275+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
276+
) where {RT, TD, TV}
277+
cache_A, cache_DV, cache_dDVtrunc, ind = cache
278+
Aval = cache_A
279+
D, V = cache_DV
280+
dD, dV = cache_dDVtrunc
281+
if !isa(A, Const) && !isa(DV, Const)
282+
$pb(A.dval, Aval, (D, V), (dD, dV), ind)
283+
end
284+
!isa(DV, Const) && make_zero!(DV.dval)
285+
return (nothing, nothing, nothing)
365286
end
366-
!isa(DV, Const) && make_zero!(DV.dval)
367-
!isa(ϵ, Const) && make_zero!.dval)
368-
return (nothing, nothing, nothing, nothing)
369287
end
370288

371289
for (f!, f_full!, pb!) in (

test/enzyme.jl

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ using Enzyme, EnzymeTestUtils
77
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
88
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99

10-
is_ci = get(ENV, "CI", "false") == "true"
11-
12-
ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
10+
ETs = (Float32, ComplexF64)
1311
include("ad_utils.jl")
1412
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated)
1513
ΔA = randn(rng, eltype(A), size(A)...)
@@ -188,10 +186,8 @@ end
188186
Vtrunc = V[:, ind]
189187
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
190188
ΔVtrunc = ΔV[:, ind]
191-
# broken due to Enzyme
192-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
193-
# broken due to Enzyme
194-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
189+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
190+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
195191
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
196192
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
197193
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -202,10 +198,8 @@ end
202198
Vtrunc = V[:, ind]
203199
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
204200
ΔVtrunc = ΔV[:, ind]
205-
# broken due to Enzyme
206-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
207-
# broken due to Enzyme
208-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
201+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
202+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
209203
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
210204
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
211205
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -253,24 +247,24 @@ function copy_eigh_vals!(A, D, alg; kwargs...)
253247
return eigh_vals!(A, D, alg; kwargs...)
254248
end
255249

256-
function copy_eigh_trunc(A; kwargs...)
250+
function copy_eigh_trunc_no_error(A; kwargs...)
257251
A = (A + A') / 2
258-
return eigh_trunc(A; kwargs...)
252+
return eigh_trunc_no_error(A; kwargs...)
259253
end
260254

261-
function copy_eigh_trunc!(A, DV; kwargs...)
255+
function copy_eigh_trunc_no_error!(A, DV; kwargs...)
262256
A = (A + A') / 2
263-
return eigh_trunc!(A, DV; kwargs...)
257+
return eigh_trunc_no_error!(A, DV; kwargs...)
264258
end
265259

266-
function copy_eigh_trunc(A, alg; kwargs...)
260+
function copy_eigh_trunc_no_error(A, alg; kwargs...)
267261
A = (A + A') / 2
268-
return eigh_trunc(A; kwargs...)
262+
return eigh_trunc_no_error(A, alg; kwargs...)
269263
end
270264

271-
function copy_eigh_trunc!(A, DV, alg; kwargs...)
265+
function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...)
272266
A = (A + A') / 2
273-
return eigh_trunc!(A, DV; kwargs...)
267+
return eigh_trunc_no_error!(A, DV, alg; kwargs...)
274268
end
275269

276270
@timedtestset "EIGH AD Rules with eltype $T" for T in ETs
@@ -307,9 +301,8 @@ end
307301
Vtrunc = V[:, ind]
308302
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
309303
ΔVtrunc = ΔV[:, ind]
310-
# broken due to Enzyme
311-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
304+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
305+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
313306
end
314307
Ddiag = diagview(D)
315308
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -318,9 +311,8 @@ end
318311
Vtrunc = V[:, ind]
319312
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
320313
ΔVtrunc = ΔV[:, ind]
321-
# broken due to Enzyme
322-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
323-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
315+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
324316
end
325317
end
326318
end
@@ -373,21 +365,21 @@ end
373365
@testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
374366
for r in 1:4:minmn
375367
U, S, Vᴴ = svd_compact(A)
376-
ΔU = randn(rng, T, m, minmn)
377-
ΔS = randn(rng, real(T), minmn, minmn)
368+
ΔU = randn(rng, T, m, minmn)
369+
ΔS = randn(rng, real(T), minmn, minmn)
378370
ΔS2 = Diagonal(randn(rng, real(T), minmn))
379371
ΔVᴴ = randn(rng, T, minmn, n)
380372
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
381373
truncalg = TruncatedAlgorithm(alg, truncrank(r))
382374
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
383-
Strunc = Diagonal(diagview(S)[ind])
384-
Utrunc = U[:, ind]
385-
Vᴴtrunc = Vᴴ[ind, :]
386-
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
387-
ΔUtrunc = ΔU[:, ind]
375+
Strunc = Diagonal(diagview(S)[ind])
376+
Utrunc = U[:, ind]
377+
Vᴴtrunc = Vᴴ[ind, :]
378+
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
379+
ΔUtrunc = ΔU[:, ind]
388380
ΔVᴴtrunc = ΔVᴴ[ind, :]
389-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
390-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
381+
test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
382+
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
391383
end
392384
U, S, Vᴴ = svd_compact(A)
393385
ΔU = randn(rng, T, m, minmn)
@@ -403,8 +395,8 @@ end
403395
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
404396
ΔUtrunc = ΔU[:, ind]
405397
ΔVᴴtrunc = ΔVᴴ[ind, :]
406-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
407-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act=RT)
398+
test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
399+
test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
408400
end
409401
end
410402
end

0 commit comments

Comments
 (0)