Skip to content

Commit c591e2e

Browse files
kshyattKatharine Hyatt
andauthored
Split svd_trunc (#116)
* Split svd_trunc Now we have a `svd_trunc_with_err` that returns epsilon for those who wish, and `svd_trunc` returns only the truncated USVh * Fix coverage * S.diag to diagview(S) * Reduce duplication * Switch to svd_trunc and svd_trunc_no_error * update docstring and test --------- Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>
1 parent cf57841 commit c591e2e

File tree

10 files changed

+200
-50
lines changed

10 files changed

+200
-50
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,25 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
193193
return svd_trunc_pullback
194194
end
195195

196+
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm)
197+
Ac = copy_input(svd_compact, A)
198+
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
199+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
200+
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
201+
end
202+
function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
203+
function svd_trunc_pullback(ΔUSVᴴ)
204+
ΔA = zero(A)
205+
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
206+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
207+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
208+
end
209+
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
210+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
211+
end
212+
return svd_trunc_pullback
213+
end
214+
196215
function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
197216
USVᴴ = svd_compact(A, alg)
198217
function svd_vals_pullback(ΔS)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,35 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
319319
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
320320
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
321321
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
322-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
322+
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
323+
U, dU = arrayify(Utrunc, dUtrunc_)
324+
S, dS = arrayify(Strunc, dStrunc_)
325+
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
326+
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
327+
MatrixAlgebraKit.zero!(dU)
328+
MatrixAlgebraKit.zero!(dS)
329+
MatrixAlgebraKit.zero!(dVᴴ)
330+
return NoRData(), NoRData(), NoRData()
331+
end
332+
return output_codual, svd_trunc_adjoint
333+
end
334+
335+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
336+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
337+
# compute primal
338+
A_ = Mooncake.primal(A_dA)
339+
dA_ = Mooncake.tangent(A_dA)
340+
A, dA = arrayify(A_, dA_)
341+
alg = Mooncake.primal(alg_dalg)
342+
output = svd_trunc_no_error(A, alg)
343+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
344+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
345+
# pass). For many types this is done automatically when the forward step returns, but
346+
# not for nested structs with various fields (like Diagonal{Complex})
347+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
348+
function svd_trunc_adjoint(::NoRData)
349+
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
350+
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
323351
U, dU = arrayify(Utrunc, dUtrunc_)
324352
S, dS = arrayify(Strunc, dStrunc_)
325353
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric
1616
export project_hermitian!, project_antihermitian!, project_isometric!
1717
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1818
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
19-
export svd_compact, svd_full, svd_vals, svd_trunc
20-
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
19+
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error
20+
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error!
2121
export eigh_full, eigh_vals, eigh_trunc
2222
export eigh_full!, eigh_vals!, eigh_trunc!
2323
export eig_full, eig_vals, eig_trunc

src/implementations/svd.jl

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
44
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
55
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
6-
copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
6+
copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_no_error)}, A) = copy_input(svd_compact, A)
77

88
copy_input(::typeof(svd_full), A::Diagonal) = copy(A)
99

@@ -89,7 +89,7 @@ end
8989
function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm)
9090
return similar(A, real(eltype(A)), (min(size(A)...),))
9191
end
92-
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
92+
function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A, alg::TruncatedAlgorithm)
9393
return initialize_output(svd_compact!, A, alg.alg)
9494
end
9595

@@ -159,17 +159,17 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
159159
if alg isa LAPACK_QRIteration
160160
isempty(alg_kwargs) ||
161161
throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration"))
162-
YALAPACK.gesvd!(A, S.diag, U, Vᴴ)
162+
YALAPACK.gesvd!(A, diagview(S), U, Vᴴ)
163163
elseif alg isa LAPACK_DivideAndConquer
164164
isempty(alg_kwargs) ||
165165
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
166-
YALAPACK.gesdd!(A, S.diag, U, Vᴴ)
166+
YALAPACK.gesdd!(A, diagview(S), U, Vᴴ)
167167
elseif alg isa LAPACK_Bisection
168-
YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg_kwargs...)
168+
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
169169
elseif alg isa LAPACK_Jacobi
170170
isempty(alg_kwargs) ||
171171
throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi"))
172-
YALAPACK.gesvj!(A, S.diag, U, Vᴴ)
172+
YALAPACK.gesvj!(A, diagview(S), U, Vᴴ)
173173
else
174174
throw(ArgumentError("Unsupported SVD algorithm"))
175175
end
@@ -206,19 +206,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206
return S
207207
end
208208

209-
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
210-
ϵ = similar(A, real(eltype(A)), compute_error)
211-
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
212-
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
209+
function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
210+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
211+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
212+
return USVᴴtrunc
213213
end
214214

215-
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
216-
U, S, Vᴴ, ϵ = USVᴴϵ
217-
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
215+
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
216+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
218217
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
219-
if !isempty(ϵ)
220-
ϵ .= truncation_error!(diagview(S), ind)
221-
end
218+
ϵ = truncation_error!(diagview(S), ind)
222219
return USVᴴtrunc..., ϵ
223220
end
224221

@@ -272,7 +269,7 @@ end
272269
###
273270

274271
function check_input(
275-
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
272+
::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
276273
)
277274
m, n = size(A)
278275
minmn = min(m, n)
@@ -288,7 +285,7 @@ function check_input(
288285
end
289286

290287
function initialize_output(
291-
::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
288+
::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
292289
)
293290
m, n = size(A)
294291
minmn = min(m, n)
@@ -372,22 +369,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
372369
return USVᴴ
373370
end
374371

375-
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
376-
U, S, Vᴴ, ϵ = USVᴴϵ
372+
function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
373+
U, S, Vᴴ = USVᴴ
374+
check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg.alg)
375+
_gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...)
376+
377+
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
378+
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
379+
380+
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
381+
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
382+
383+
return Utr, Str, Vᴴtr
384+
end
385+
386+
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
387+
U, S, Vᴴ = USVᴴ
377388
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
378-
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
389+
_gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...)
379390

380391
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
381392
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
382393

383-
if !isempty(ϵ)
384-
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
385-
normS = norm(diagview(Str))
386-
normA = norm(A)
387-
# equivalent to sqrt(normA^2 - normS^2)
388-
# but may be more accurate
389-
ϵ = sqrt((normA + normS) * (normA - normS))
390-
end
394+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
395+
normS = norm(diagview(Str))
396+
normA = norm(A)
397+
# equivalent to sqrt(normA^2 - normS^2)
398+
# but may be more accurate
399+
ϵ = sqrt((normA + normS) * (normA - normS))
391400

392401
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
393402
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
@@ -404,11 +413,11 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
404413

405414
if alg isa GPU_QRIteration
406415
isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration"
407-
_gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ)
416+
_gpu_gesvd_maybe_transpose!(A, diagview(S), U, Vᴴ)
408417
elseif alg isa GPU_SVDPolar
409-
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg_kwargs...)
418+
_gpu_Xgesvdp!(A, diagview(S), U, Vᴴ; alg_kwargs...)
410419
elseif alg isa GPU_Jacobi
411-
_gpu_gesvdj!(A, S.diag, U, Vᴴ; alg_kwargs...)
420+
_gpu_gesvdj!(A, diagview(S), U, Vᴴ; alg_kwargs...)
412421
else
413422
throw(ArgumentError("Unsupported SVD algorithm"))
414423
end

src/interface/svd.jl

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,61 @@ truncation strategy is already embedded in the algorithm.
8686
possibly destroys the input matrix `A`. Always use the return value of the function
8787
as it may not always be possible to use the provided `USVᴴ` as output.
8888
89-
See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact),
90-
[`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on
91-
truncation strategies.
89+
See also [`svd_trunc_no_error(!)`](@ref svd_trunc_no_error), [`svd_full(!)`](@ref svd_full),
90+
[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals),
91+
and [Truncations](@ref) for more information on truncation strategies.
9292
"""
9393
@functiondef svd_trunc
9494

95+
"""
96+
svd_trunc_no_error(A; [trunc], kwargs...) -> U, S, Vᴴ
97+
svd_trunc_no_error(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
98+
svd_trunc_no_error!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
99+
svd_trunc_no_error!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
100+
101+
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
102+
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
103+
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
104+
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.
105+
The truncation error is *not* returned.
106+
107+
## Truncation
108+
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
109+
either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or
110+
nothing, all values will be kept.
111+
112+
### `trunc::NamedTuple`
113+
The supported truncation keyword arguments are:
114+
115+
$docs_truncation_kwargs
116+
117+
### `trunc::TruncationStrategy`
118+
For more control, a truncation strategy can be supplied directly.
119+
By default, MatrixAlgebraKit supplies the following:
120+
121+
$docs_truncation_strategies
122+
123+
## Keyword arguments
124+
Other keyword arguments are passed to the algorithm selection procedure. If no explicit
125+
`alg` is provided, these keywords are used to select and configure the algorithm through
126+
[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm
127+
selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref)
128+
for the default algorithm selection behavior.
129+
130+
When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
131+
truncation strategy is already embedded in the algorithm.
132+
133+
!!! note
134+
The bang method `svd_trunc_no_error!` optionally accepts the output structure and
135+
possibly destroys the input matrix `A`. Always use the return value of the function
136+
as it may not always be possible to use the provided `USVᴴ` as output.
137+
138+
See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact),
139+
[`svd_vals(!)`](@ref svd_vals), [`svd_trunc(!)`](@ref svd_trunc) and
140+
[Truncations](@ref) for more information on truncation strategies.
141+
"""
142+
@functiondef svd_trunc_no_error
143+
95144
"""
96145
svd_vals(A; kwargs...) -> S
97146
svd_vals(A, alg::AbstractAlgorithm) -> S
@@ -125,13 +174,15 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
125174
end
126175
end
127176

128-
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc = nothing, kwargs...)
129-
if alg isa TruncatedAlgorithm
130-
isnothing(trunc) ||
131-
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
132-
return alg
133-
else
134-
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
135-
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
177+
for f in (:svd_trunc!, :svd_trunc_no_error!)
178+
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
179+
if alg isa TruncatedAlgorithm
180+
isnothing(trunc) ||
181+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
182+
return alg
183+
else
184+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
185+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
186+
end
136187
end
137188
end

test/chainrules.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ for f in
1212
(
1313
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
1414
:eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
15-
:svd_compact, :svd_trunc, :svd_vals,
15+
:svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals,
1616
:left_polar, :right_polar,
1717
)
1818
copy_f = Symbol(:copy_, f)
@@ -434,6 +434,11 @@ end
434434
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
435435
atol = atol, rtol = rtol
436436
)
437+
test_rrule(
438+
copy_svd_trunc_no_error, A, truncalg NoTangent();
439+
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
440+
atol = atol, rtol = rtol
441+
)
437442
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
438443
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
439444
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -451,6 +456,11 @@ end
451456
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
452457
atol = atol, rtol = rtol
453458
)
459+
test_rrule(
460+
copy_svd_trunc_no_error, A, truncalg NoTangent();
461+
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
462+
atol = atol, rtol = rtol
463+
)
454464
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
455465
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
456466
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -480,6 +490,12 @@ end
480490
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
481491
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
482492
)
493+
test_rrule(
494+
config, svd_trunc_no_error, A;
495+
fkwargs = (; trunc = trunc),
496+
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]),
497+
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
498+
)
483499
end
484500
trunc = trunctol(; atol = S[1, 1] / 2)
485501
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)
@@ -489,6 +505,12 @@ end
489505
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
490506
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
491507
)
508+
test_rrule(
509+
config, svd_trunc_no_error, A;
510+
fkwargs = (; trunc = trunc),
511+
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]),
512+
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
513+
)
492514
end
493515
end
494516

test/cuda/svd.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ end
144144
@test length(S1.diag) == r
145145
@test opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
146146
@test norm(A - U1 * S1 * V1ᴴ) ϵ1
147+
U1, S1, V1ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = truncrank(r))
148+
@test length(S1.diag) == r
149+
@test opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
147150

148151
if !(alg isa CUSOLVER_Randomized)
149152
s = 1 + sqrt(eps(real(T)))
@@ -154,6 +157,12 @@ end
154157
@test U1 U2
155158
@test parent(S1) parent(S2)
156159
@test V1ᴴ V2ᴴ
160+
161+
U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = trunctol(; atol = s * S₀[r + 1]))
162+
@test length(S2.diag) == r
163+
@test U1 U2
164+
@test parent(S1) parent(S2)
165+
@test V1ᴴ V2ᴴ
157166
end
158167
end
159168
end

0 commit comments

Comments
 (0)