Skip to content

Commit ebffb20

Browse files
committed
Split svd_trunc
Now we have a `svd_trunc_with_err` that returns epsilon for those who wish, and `svd_trunc` returns only the truncated USVh
1 parent cf57841 commit ebffb20

File tree

12 files changed

+251
-71
lines changed

12 files changed

+251
-71
lines changed

docs/src/user_interface/truncations.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);
113113

114114
## Truncation Error
115115

116-
When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
116+
When using truncated decompositions such as [`svd_trunc_with_err`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
117117
This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality.
118-
For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
118+
For `svd_trunc_with_err` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
119119
For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality.
120120

121121

122122
For example:
123123
```jldoctest truncations; output=false
124124
using LinearAlgebra: norm
125-
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
125+
U, S, Vᴴ, ϵ = svd_trunc_with_err(A; trunc=truncrank(2))
126126
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values
127127
128128
# output

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,15 @@ for svd_f in (:svd_compact, :svd_full)
170170
end
171171
end
172172

173-
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
173+
function ChainRulesCore.rrule(::typeof(svd_trunc_with_err!), A, USVᴴ, alg::TruncatedAlgorithm)
174174
Ac = copy_input(svd_compact, A)
175175
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
176176
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
177177
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
178-
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
178+
return (USVᴴ′..., ϵ), _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
179179
end
180-
function _make_svd_trunc_pullback(A, USVᴴ, ind)
181-
function svd_trunc_pullback(ΔUSVᴴϵ)
180+
function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
181+
function svd_trunc_with_err_pullback(ΔUSVᴴϵ)
182182
ΔA = zero(A)
183183
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
184184
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
@@ -187,6 +187,25 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
187187
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
188188
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
189189
end
190+
function svd_trunc_with_err_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
191+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
192+
end
193+
return svd_trunc_with_err_pullback
194+
end
195+
196+
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
197+
Ac = copy_input(svd_compact, A)
198+
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
199+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
200+
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
201+
end
202+
function _make_svd_trunc_pullback(A, USVᴴ, ind)
203+
function svd_trunc_pullback(ΔUSVᴴ)
204+
ΔA = zero(A)
205+
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
206+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
207+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
208+
end
190209
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
191210
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
192211
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
303303
return S_codual, svd_vals_adjoint
304304
end
305305

306-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
307-
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
306+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_with_err), Any, MatrixAlgebraKit.AbstractAlgorithm}
307+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, alg_dalg::CoDual)
308308
# compute primal
309309
A_ = Mooncake.primal(A_dA)
310310
dA_ = Mooncake.tangent(A_dA)
311311
A, dA = arrayify(A_, dA_)
312312
alg = Mooncake.primal(alg_dalg)
313-
output = svd_trunc(A, alg)
313+
output = svd_trunc_with_err(A, alg)
314314
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
315315
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
316316
# pass). For many types this is done automatically when the forward step returns, but
@@ -319,7 +319,35 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
319319
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
320320
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
321321
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
322-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
322+
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc_with_err does not yet support non-zero tangent for the truncation error"
323+
U, dU = arrayify(Utrunc, dUtrunc_)
324+
S, dS = arrayify(Strunc, dStrunc_)
325+
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
326+
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
327+
MatrixAlgebraKit.zero!(dU)
328+
MatrixAlgebraKit.zero!(dS)
329+
MatrixAlgebraKit.zero!(dVᴴ)
330+
return NoRData(), NoRData(), NoRData()
331+
end
332+
return output_codual, svd_trunc_adjoint
333+
end
334+
335+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
336+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
337+
# compute primal
338+
A_ = Mooncake.primal(A_dA)
339+
dA_ = Mooncake.tangent(A_dA)
340+
A, dA = arrayify(A_, dA_)
341+
alg = Mooncake.primal(alg_dalg)
342+
output = svd_trunc(A, alg)
343+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
344+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
345+
# pass). For many types this is done automatically when the forward step returns, but
346+
# not for nested structs with various fields (like Diagonal{Complex})
347+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
348+
function svd_trunc_adjoint(::NoRData)
349+
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
350+
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
323351
U, dU = arrayify(Utrunc, dUtrunc_)
324352
S, dS = arrayify(Strunc, dStrunc_)
325353
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric
1616
export project_hermitian!, project_antihermitian!, project_isometric!
1717
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1818
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
19-
export svd_compact, svd_full, svd_vals, svd_trunc
20-
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
19+
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_with_err
20+
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_with_err!
2121
export eigh_full, eigh_vals, eigh_trunc
2222
export eigh_full!, eigh_vals!, eigh_trunc!
2323
export eig_full, eig_vals, eig_trunc

src/implementations/svd.jl

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltyp
44
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
55
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
66
copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
7+
copy_input(::typeof(svd_trunc_with_err), A) = copy_input(svd_compact, A)
78

89
copy_input(::typeof(svd_full), A::Diagonal) = copy(A)
910

@@ -92,6 +93,9 @@ end
9293
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
9394
return initialize_output(svd_compact!, A, alg.alg)
9495
end
96+
function initialize_output(::typeof(svd_trunc_with_err!), A, alg::TruncatedAlgorithm)
97+
return initialize_output(svd_compact!, A, alg.alg)
98+
end
9599

96100
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
97101
TA = eltype(A)
@@ -206,19 +210,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206210
return S
207211
end
208212

209-
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
210-
ϵ = similar(A, real(eltype(A)), compute_error)
211-
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
212-
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
213+
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
214+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
215+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
216+
return USVᴴtrunc
213217
end
214218

215-
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
216-
U, S, Vᴴ, ϵ = USVᴴϵ
217-
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
219+
function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm)
220+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
218221
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
219-
if !isempty(ϵ)
220-
ϵ .= truncation_error!(diagview(S), ind)
221-
end
222+
ϵ = truncation_error!(diagview(S), ind)
222223
return USVᴴtrunc..., ϵ
223224
end
224225

@@ -287,6 +288,22 @@ function check_input(
287288
return nothing
288289
end
289290

291+
function check_input(
292+
::typeof(svd_trunc_with_err!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
293+
)
294+
m, n = size(A)
295+
minmn = min(m, n)
296+
U, S, Vᴴ = USVᴴ
297+
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
298+
@check_size(U, (m, m))
299+
@check_scalar(U, A)
300+
@check_size(S, (minmn, minmn))
301+
@check_scalar(S, A, real)
302+
@check_size(Vᴴ, (n, n))
303+
@check_scalar(Vᴴ, A)
304+
return nothing
305+
end
306+
290307
function initialize_output(
291308
::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
292309
)
@@ -298,6 +315,17 @@ function initialize_output(
298315
return (U, S, Vᴴ)
299316
end
300317

318+
function initialize_output(
319+
::typeof(svd_trunc_with_err!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
320+
)
321+
m, n = size(A)
322+
minmn = min(m, n)
323+
U = similar(A, (m, m))
324+
S = Diagonal(similar(A, real(eltype(A)), (minmn,)))
325+
Vᴴ = similar(A, (n, n))
326+
return (U, S, Vᴴ)
327+
end
328+
301329
function _gpu_gesvd!(
302330
A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix
303331
)
@@ -372,22 +400,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
372400
return USVᴴ
373401
end
374402

375-
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
376-
U, S, Vᴴ, ϵ = USVᴴϵ
403+
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
404+
U, S, Vᴴ = USVᴴ
377405
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
378406
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
379407

380408
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
381409
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
382410

383-
if !isempty(ϵ)
384-
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
385-
normS = norm(diagview(Str))
386-
normA = norm(A)
387-
# equivalent to sqrt(normA^2 - normS^2)
388-
# but may be more accurate
389-
ϵ = sqrt((normA + normS) * (normA - normS))
390-
end
411+
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
412+
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
413+
414+
return Utr, Str, Vᴴtr
415+
end
416+
417+
function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
418+
U, S, Vᴴ = USVᴴ
419+
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
420+
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
421+
422+
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
423+
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
424+
425+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
426+
normS = norm(diagview(Str))
427+
normA = norm(A)
428+
# equivalent to sqrt(normA^2 - normS^2)
429+
# but may be more accurate
430+
ϵ = sqrt((normA + normS) * (normA - normS))
391431

392432
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
393433
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)

src/interface/svd.jl

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
4242
@functiondef svd_compact
4343

4444
"""
45-
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
46-
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
47-
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
48-
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
45+
svd_trunc_with_err(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
46+
svd_trunc_with_err(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
47+
svd_trunc_with_err!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
48+
svd_trunc_with_err!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
4949
5050
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
5151
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
@@ -81,6 +81,54 @@ for the default algorithm selection behavior.
8181
When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
8282
truncation strategy is already embedded in the algorithm.
8383
84+
!!! note
85+
The bang method `svd_trunc!` optionally accepts the output structure and
86+
possibly destroys the input matrix `A`. Always use the return value of the function
87+
as it may not always be possible to use the provided `USVᴴ` as output.
88+
89+
See also [`svd_trunc(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full),
90+
[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals),
91+
and [Truncations](@ref) for more information on truncation strategies.
92+
"""
93+
@functiondef svd_trunc_with_err
94+
95+
"""
96+
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ
97+
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
98+
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
99+
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
100+
101+
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
102+
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
103+
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
104+
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.
105+
106+
## Truncation
107+
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
108+
either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or
109+
nothing, all values will be kept.
110+
111+
### `trunc::NamedTuple`
112+
The supported truncation keyword arguments are:
113+
114+
$docs_truncation_kwargs
115+
116+
### `trunc::TruncationStrategy`
117+
For more control, a truncation strategy can be supplied directly.
118+
By default, MatrixAlgebraKit supplies the following:
119+
120+
$docs_truncation_strategies
121+
122+
## Keyword arguments
123+
Other keyword arguments are passed to the algorithm selection procedure. If no explicit
124+
`alg` is provided, these keywords are used to select and configure the algorithm through
125+
[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm
126+
selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref)
127+
for the default algorithm selection behavior.
128+
129+
When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
130+
truncation strategy is already embedded in the algorithm.
131+
84132
!!! note
85133
The bang method `svd_trunc!` optionally accepts the output structure and
86134
possibly destroys the input matrix `A`. Always use the return value of the function
@@ -125,13 +173,15 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
125173
end
126174
end
127175

128-
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc = nothing, kwargs...)
129-
if alg isa TruncatedAlgorithm
130-
isnothing(trunc) ||
131-
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
132-
return alg
133-
else
134-
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
135-
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
176+
for f in (:svd_trunc!, :svd_trunc_with_err!)
177+
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
178+
if alg isa TruncatedAlgorithm
179+
isnothing(trunc) ||
180+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
181+
return alg
182+
else
183+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
184+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
185+
end
136186
end
137187
end

test/amd/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ end
140140
# minmn = min(m, n)
141141
# r = minmn - 2
142142
#
143-
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
143+
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc=truncrank(r))
144144
# @test length(S1.diag) == r
145145
# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
146146
#
147147
# s = 1 + sqrt(eps(real(T)))
148148
# trunc2 = trunctol(; atol=s * S₀[r + 1])
149149
#
150-
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
150+
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
151151
# @test length(S2.diag) == r
152152
# @test U1 ≈ U2
153153
# @test S1 ≈ S2

0 commit comments

Comments
 (0)