@@ -206,10 +206,35 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206 return S
207207end
208208
209- function svd_trunc! (A, USVᴴ, alg:: TruncatedAlgorithm )
209+ # nothing case here to handle GenericLinearAlgebra
210+ function svd_trunc! (A, USVᴴϵ:: Tuple{TU, TS, TVᴴ, Tϵ} , alg:: TruncatedAlgorithm ) where {TU, TS, TVᴴ, Tϵ}
211+ U, S, Vᴴ, ϵ = USVᴴϵ
212+ U, S, Vᴴ = svd_compact! (A, (U, S, Vᴴ), alg. alg)
213+ USVᴴtrunc, ind = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
214+ if ! isempty (ϵ)
215+ ϵ[1 ] = truncation_error! (diagview (S), ind)
216+ end
217+ return USVᴴtrunc... , ϵ
218+ end
219+ function svd_trunc! (A, USVᴴϵ:: Tuple{Nothing, Tϵ} , alg:: TruncatedAlgorithm ) where {Tϵ}
220+ USVᴴ, ϵ = USVᴴϵ
210221 U, S, Vᴴ = svd_compact! (A, USVᴴ, alg. alg)
211222 USVᴴtrunc, ind = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
212- return USVᴴtrunc... , truncation_error! (diagview (S), ind)
223+ if ! isempty (ϵ)
224+ ϵ[1 ] = truncation_error! (diagview (S), ind)
225+ end
226+ return USVᴴtrunc... , ϵ
227+ end
228+
229+ function svd_trunc! (A, USVᴴ:: Tuple{TU, TS, TVᴴ} , alg:: TruncatedAlgorithm ; compute_error:: Bool = true ) where {TU, TS, TVᴴ}
230+ ϵ = compute_error ? zeros (real (eltype (A)), 1 ) : zeros (real (eltype (A)), 0 )
231+ (U, S, Vᴴ, ϵ) = svd_trunc! (A, (USVᴴ... , ϵ), alg)
232+ return compute_error ? (U, S, Vᴴ, ϵ[1 ]) : (U, S, Vᴴ, NaN )
233+ end
234+ function svd_trunc! (A, USVᴴ:: Nothing , alg:: TruncatedAlgorithm ; compute_error:: Bool = true )
235+ ϵ = compute_error ? zeros (real (eltype (A)), 1 ) : zeros (real (eltype (A)), 0 )
236+ U, S, Vᴴ, ϵ = svd_trunc! (A, (USVᴴ, ϵ), alg)
237+ return compute_error ? (U, S, Vᴴ, ϵ[1 ]) : (U, S, Vᴴ, NaN )
213238end
214239
215240# Diagonal logic
@@ -362,16 +387,18 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
362387 return USVᴴ
363388end
364389
365- function svd_trunc! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
366- check_input (svd_trunc!, A, USVᴴ, alg . alg)
367- U, S, Vᴴ = USVᴴ
390+ function svd_trunc! (A:: AbstractMatrix , USVᴴϵ :: Tuple{TU, TS, TVᴴ, Tϵ} , alg:: TruncatedAlgorithm{<:GPU_Randomized} ) where {TU, TS, TVᴴ, Tϵ}
391+ U, S, Vᴴ, ϵ = USVᴴϵ
392+ check_input (svd_trunc!, A, ( U, S, Vᴴ), alg . alg)
368393 _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. alg. kwargs... )
369394
370395 # TODO : make sure that truncation is based on maxrank, otherwise this might be wrong
371396 (Utr, Str, Vᴴtr), _ = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
372397
373- # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
374- ϵ = sqrt (norm (A)^ 2 - norm (diagview (Str))^ 2 ) # is there a more accurate way to do this?
398+ if ! isempty (ϵ)
399+ # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
400+ ϵ = sqrt (norm (A)^ 2 - norm (diagview (Str))^ 2 ) # is there a more accurate way to do this?
401+ end
375402
376403 do_gauge_fix = get (alg. alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
377404 do_gauge_fix && gaugefix! (svd_trunc!, Utr, Vᴴtr)
0 commit comments