Skip to content

Commit 2dbc5e6

Browse files
author
Katharine Hyatt
committed
Add optional epsilon vector and keyword arg for error
1 parent ba24875 commit 2dbc5e6

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

src/implementations/svd.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,35 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206
return S
207207
end
208208

209-
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
209+
# nothing case here to handle GenericLinearAlgebra
210+
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
211+
U, S, Vᴴ, ϵ = USVᴴϵ
212+
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
213+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
214+
if !isempty(ϵ)
215+
ϵ[1] = truncation_error!(diagview(S), ind)
216+
end
217+
return USVᴴtrunc..., ϵ
218+
end
219+
function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) where {Tϵ}
220+
USVᴴ, ϵ = USVᴴϵ
210221
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
211222
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
212-
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
223+
if !isempty(ϵ)
224+
ϵ[1] = truncation_error!(diagview(S), ind)
225+
end
226+
return USVᴴtrunc..., ϵ
227+
end
228+
229+
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool=true) where {TU, TS, TVᴴ}
230+
ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0)
231+
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
232+
return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN)
233+
end
234+
function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool=true)
235+
ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0)
236+
U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg)
237+
return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN)
213238
end
214239

215240
# Diagonal logic
@@ -362,16 +387,18 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
362387
return USVᴴ
363388
end
364389

365-
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
366-
check_input(svd_trunc!, A, USVᴴ, alg.alg)
367-
U, S, Vᴴ = USVᴴ
390+
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
391+
U, S, Vᴴ, ϵ = USVᴴϵ
392+
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
368393
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
369394

370395
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
371396
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
372397

373-
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
374-
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
398+
if !isempty(ϵ)
399+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
400+
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
401+
end
375402

376403
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
377404
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)

0 commit comments

Comments
 (0)