33copy_input (:: typeof (svd_full), A:: AbstractMatrix ) = copy! (similar (A, float (eltype (A))), A)
44copy_input (:: typeof (svd_compact), A) = copy_input (svd_full, A)
55copy_input (:: typeof (svd_vals), A) = copy_input (svd_full, A)
6- copy_input (:: typeof (svd_trunc), A) = copy_input (svd_compact, A)
6+ copy_input (:: Union{ typeof(svd_trunc), typeof(svd_trunc_no_error)} , A) = copy_input (svd_compact, A)
77
88copy_input (:: typeof (svd_full), A:: Diagonal ) = copy (A)
99
8989function initialize_output (:: typeof (svd_vals!), A:: AbstractMatrix , :: AbstractAlgorithm )
9090 return similar (A, real (eltype (A)), (min (size (A)... ),))
9191end
92- function initialize_output (:: typeof (svd_trunc!), A, alg:: TruncatedAlgorithm )
92+ function initialize_output (:: Union{ typeof(svd_trunc!), typeof(svd_trunc_no_error!)} , A, alg:: TruncatedAlgorithm )
9393 return initialize_output (svd_compact!, A, alg. alg)
9494end
9595
@@ -159,17 +159,17 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
159159 if alg isa LAPACK_QRIteration
160160 isempty (alg_kwargs) ||
161161 throw (ArgumentError (" invalid keyword arguments for LAPACK_QRIteration" ))
162- YALAPACK. gesvd! (A, S . diag , U, Vᴴ)
162+ YALAPACK. gesvd! (A, diagview (S) , U, Vᴴ)
163163 elseif alg isa LAPACK_DivideAndConquer
164164 isempty (alg_kwargs) ||
165165 throw (ArgumentError (" invalid keyword arguments for LAPACK_DivideAndConquer" ))
166- YALAPACK. gesdd! (A, S . diag , U, Vᴴ)
166+ YALAPACK. gesdd! (A, diagview (S) , U, Vᴴ)
167167 elseif alg isa LAPACK_Bisection
168- YALAPACK. gesvdx! (A, S . diag , U, Vᴴ; alg_kwargs... )
168+ YALAPACK. gesvdx! (A, diagview (S) , U, Vᴴ; alg_kwargs... )
169169 elseif alg isa LAPACK_Jacobi
170170 isempty (alg_kwargs) ||
171171 throw (ArgumentError (" invalid keyword arguments for LAPACK_Jacobi" ))
172- YALAPACK. gesvj! (A, S . diag , U, Vᴴ)
172+ YALAPACK. gesvj! (A, diagview (S) , U, Vᴴ)
173173 else
174174 throw (ArgumentError (" Unsupported SVD algorithm" ))
175175 end
@@ -206,19 +206,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206 return S
207207end
208208
209- function svd_trunc ! (A, USVᴴ:: Tuple{TU, TS, TVᴴ} , alg:: TruncatedAlgorithm ; compute_error :: Bool = true ) where {TU, TS, TVᴴ}
210- ϵ = similar (A, real ( eltype (A)), compute_error )
211- (U, S, Vᴴ, ϵ) = svd_trunc! (A , (USVᴴ ... , ϵ ), alg)
212- return compute_error ? (U, S, Vᴴ, norm (ϵ)) : (U, S, Vᴴ, - one ( eltype (ϵ)))
209+ function svd_trunc_no_error ! (A, USVᴴ, alg:: TruncatedAlgorithm )
210+ U, S, Vᴴ = svd_compact! (A, USVᴴ, alg . alg )
211+ USVᴴtrunc, ind = truncate ( svd_trunc!, (U, S, Vᴴ ), alg. trunc )
212+ return USVᴴtrunc
213213end
214214
215- function svd_trunc! (A, USVᴴϵ:: Tuple{TU, TS, TVᴴ, Tϵ} , alg:: TruncatedAlgorithm ) where {TU, TS, TVᴴ, Tϵ}
216- U, S, Vᴴ, ϵ = USVᴴϵ
217- U, S, Vᴴ = svd_compact! (A, (U, S, Vᴴ), alg. alg)
215+ function svd_trunc! (A, USVᴴ, alg:: TruncatedAlgorithm )
216+ U, S, Vᴴ = svd_compact! (A, USVᴴ, alg. alg)
218217 USVᴴtrunc, ind = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
219- if ! isempty (ϵ)
220- ϵ .= truncation_error! (diagview (S), ind)
221- end
218+ ϵ = truncation_error! (diagview (S), ind)
222219 return USVᴴtrunc... , ϵ
223220end
224221
272269# ##
273270
274271function check_input (
275- :: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ, alg:: CUSOLVER_Randomized
272+ :: Union{ typeof(svd_trunc!), typeof(svd_trunc_no_error!)} , A:: AbstractMatrix , USVᴴ, alg:: CUSOLVER_Randomized
276273 )
277274 m, n = size (A)
278275 minmn = min (m, n)
@@ -288,7 +285,7 @@ function check_input(
288285end
289286
290287function initialize_output (
291- :: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized}
288+ :: Union{ typeof(svd_trunc!), typeof(svd_trunc_no_error!)} , A:: AbstractMatrix , alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized}
292289 )
293290 m, n = size (A)
294291 minmn = min (m, n)
@@ -372,22 +369,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
372369 return USVᴴ
373370end
374371
375- function svd_trunc! (A:: AbstractMatrix , USVᴴϵ:: Tuple{TU, TS, TVᴴ, Tϵ} , alg:: TruncatedAlgorithm{<:GPU_Randomized} ) where {TU, TS, TVᴴ, Tϵ}
376- U, S, Vᴴ, ϵ = USVᴴϵ
372+ function svd_trunc_no_error! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
373+ U, S, Vᴴ = USVᴴ
374+ check_input (svd_trunc_no_error!, A, (U, S, Vᴴ), alg. alg)
375+ _gpu_Xgesvdr! (A, diagview (S), U, Vᴴ; alg. alg. kwargs... )
376+
377+ # TODO : make sure that truncation is based on maxrank, otherwise this might be wrong
378+ (Utr, Str, Vᴴtr), _ = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
379+
380+ do_gauge_fix = get (alg. alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
381+ do_gauge_fix && gaugefix! (svd_trunc!, Utr, Vᴴtr)
382+
383+ return Utr, Str, Vᴴtr
384+ end
385+
386+ function svd_trunc! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
387+ U, S, Vᴴ = USVᴴ
377388 check_input (svd_trunc!, A, (U, S, Vᴴ), alg. alg)
378- _gpu_Xgesvdr! (A, S . diag , U, Vᴴ; alg. alg. kwargs... )
389+ _gpu_Xgesvdr! (A, diagview (S) , U, Vᴴ; alg. alg. kwargs... )
379390
380391 # TODO : make sure that truncation is based on maxrank, otherwise this might be wrong
381392 (Utr, Str, Vᴴtr), _ = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
382393
383- if ! isempty (ϵ)
384- # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
385- normS = norm (diagview (Str))
386- normA = norm (A)
387- # equivalent to sqrt(normA^2 - normS^2)
388- # but may be more accurate
389- ϵ = sqrt ((normA + normS) * (normA - normS))
390- end
394+ # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
395+ normS = norm (diagview (Str))
396+ normA = norm (A)
397+ # equivalent to sqrt(normA^2 - normS^2)
398+ # but may be more accurate
399+ ϵ = sqrt ((normA + normS) * (normA - normS))
391400
392401 do_gauge_fix = get (alg. alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
393402 do_gauge_fix && gaugefix! (svd_trunc!, Utr, Vᴴtr)
@@ -404,11 +413,11 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
404413
405414 if alg isa GPU_QRIteration
406415 isempty (alg_kwargs) || @warn " invalid keyword arguments for GPU_QRIteration"
407- _gpu_gesvd_maybe_transpose! (A, S . diag , U, Vᴴ)
416+ _gpu_gesvd_maybe_transpose! (A, diagview (S) , U, Vᴴ)
408417 elseif alg isa GPU_SVDPolar
409- _gpu_Xgesvdp! (A, S . diag , U, Vᴴ; alg_kwargs... )
418+ _gpu_Xgesvdp! (A, diagview (S) , U, Vᴴ; alg_kwargs... )
410419 elseif alg isa GPU_Jacobi
411- _gpu_gesvdj! (A, S . diag , U, Vᴴ; alg_kwargs... )
420+ _gpu_gesvdj! (A, diagview (S) , U, Vᴴ; alg_kwargs... )
412421 else
413422 throw (ArgumentError (" Unsupported SVD algorithm" ))
414423 end
0 commit comments