Skip to content

Commit cf31cb4

Browse files
committed
Pass compute_error and epsilon as arguments
This PR adds a `compute_error` field to `TruncatedAlgorithm` and changes `eig_trunc!`, `eigh_trunc!`, and `svd_trunc!` (and their non-mutating counterparts) to accept an array `epsilon` as part of their arguments. The purpose of this is twofold: to make handling the truncation error easier in AD, and to avoid forcing GPU synchronization.
1 parent cf57841 commit cf31cb4

File tree

14 files changed

+90
-80
lines changed

14 files changed

+90
-80
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ for eig in (:eig, :eigh)
118118
Ac = copy_input($eig_f, A)
119119
DV = $(eig_f!)(Ac, DV, alg.alg)
120120
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
121-
ϵ = truncation_error(diagview(DV[1]), ind)
121+
ϵ = [truncation_error(diagview(DV[1]), ind)]
122122
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
123123
end
124124
function $(_make_eig_t_pb)(A, DV, ind)
@@ -174,7 +174,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
174174
Ac = copy_input(svd_compact, A)
175175
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
176176
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
177-
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
177+
ϵ = [truncation_error(diagview(USVᴴ[2]), ind)]
178178
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
179179
end
180180
function _make_svd_trunc_pullback(A, USVᴴ, ind)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,9 @@ for (f, pb, adj) in (
179179
# pass). For many types this is done automatically when the forward step returns, but
180180
# not for nested structs with various fields (like Diagonal{Complex})
181181
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
182-
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
182+
function $adj(::NoRData)
183183
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
184184
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
185-
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
186185
D, dD = arrayify(Dtrunc, dDtrunc_)
187186
V, dV = arrayify(Vtrunc, dVtrunc_)
188187
$pb(dA, A, (D, V), (dD, dV))
@@ -316,10 +315,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
316315
# pass). For many types this is done automatically when the forward step returns, but
317316
# not for nested structs with various fields (like Diagonal{Complex})
318317
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
319-
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
318+
function svd_trunc_adjoint(::NoRData)
320319
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
321320
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
322-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
323321
U, dU = arrayify(Utrunc, dUtrunc_)
324322
S, dS = arrayify(Strunc, dStrunc_)
325323
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)

src/algorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ truncation through `trunc`.
228228
struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
229229
alg::A
230230
trunc::T
231+
compute_error::Bool
231232
end
233+
TruncatedAlgorithm(alg::A, trunc::T; compute_error::Bool = true) where {A <: AbstractAlgorithm, T} = TruncatedAlgorithm{A, T}(alg, trunc, compute_error)
232234

233235
does_truncate(::TruncatedAlgorithm) = true
234236

src/implementations/eig.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg
6666
return D
6767
end
6868
function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm)
69-
return initialize_output(eig_full!, A, alg.alg)
69+
DV = initialize_output(eig_full!, A, alg.alg)
70+
ϵ = similar(A, real(eltype(A)), alg.compute_error)
71+
return (DV..., ϵ)
7072
end
7173

7274
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
@@ -115,10 +117,14 @@ function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
115117
return D
116118
end
117119

118-
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
119-
D, V = eig_full!(A, DV, alg.alg)
120+
function eig_trunc!(A, DVϵ, alg::TruncatedAlgorithm)
121+
D, V, ϵ = DVϵ
122+
D, V = eig_full!(A, (D, V), alg.alg)
120123
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
121-
return DVtrunc..., truncation_error!(diagview(D), ind)
124+
if !isempty(ϵ)
125+
ϵ .= truncation_error!(diagview(D), ind)
126+
end
127+
return DVtrunc..., ϵ
122128
end
123129

124130
# Diagonal logic

src/implementations/eigh.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl
7575
return D
7676
end
7777
function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm)
78-
return initialize_output(eigh_full!, A, alg.alg)
78+
DV = initialize_output(eigh_full!, A, alg.alg)
79+
ϵ = similar(A, real(eltype(A)), alg.compute_error)
80+
return (DV..., ϵ)
7981
end
8082

8183
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
@@ -129,10 +131,14 @@ function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
129131
return D
130132
end
131133

132-
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
133-
D, V = eigh_full!(A, DV, alg.alg)
134+
function eigh_trunc!(A, DVϵ, alg::TruncatedAlgorithm)
135+
D, V, ϵ = DVϵ
136+
D, V = eigh_full!(A, (D, V), alg.alg)
134137
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
135-
return DVtrunc..., truncation_error!(diagview(D), ind)
138+
if !isempty(ϵ)
139+
ϵ .= truncation_error!(diagview(D), ind)
140+
end
141+
return DVtrunc..., ϵ
136142
end
137143

138144
# Diagonal logic

src/implementations/svd.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlg
9090
return similar(A, real(eltype(A)), (min(size(A)...),))
9191
end
9292
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
93-
return initialize_output(svd_compact!, A, alg.alg)
93+
USVᴴ = initialize_output(svd_compact!, A, alg.alg)
94+
ϵ = similar(A, real(eltype(A)), alg.compute_error)
95+
return (USVᴴ..., ϵ)
9496
end
9597

9698
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
@@ -206,12 +208,6 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206208
return S
207209
end
208210

209-
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
210-
ϵ = similar(A, real(eltype(A)), compute_error)
211-
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
212-
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
213-
end
214-
215211
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
216212
U, S, Vᴴ, ϵ = USVᴴϵ
217213
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
@@ -272,18 +268,19 @@ end
272268
###
273269

274270
function check_input(
275-
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
271+
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴϵ, alg::CUSOLVER_Randomized
276272
)
277273
m, n = size(A)
278274
minmn = min(m, n)
279-
U, S, Vᴴ = USVᴴ
275+
U, S, Vᴴ, ϵ = USVᴴϵ
280276
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
281277
@check_size(U, (m, m))
282278
@check_scalar(U, A)
283279
@check_size(S, (minmn, minmn))
284280
@check_scalar(S, A, real)
285281
@check_size(Vᴴ, (n, n))
286282
@check_scalar(Vᴴ, A)
283+
@check_scalar(ϵ, A, real)
287284
return nothing
288285
end
289286

@@ -295,7 +292,8 @@ function initialize_output(
295292
U = similar(A, (m, m))
296293
S = Diagonal(similar(A, real(eltype(A)), (minmn,)))
297294
Vᴴ = similar(A, (n, n))
298-
return (U, S, Vᴴ)
295+
ϵ = similar(A, real(eltype(A)), alg.compute_error)
296+
return (U, S, Vᴴ, ϵ)
299297
end
300298

301299
function _gpu_gesvd!(

test/chainrules.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ end
247247
ΔVtrunc = ΔV[:, ind]
248248
test_rrule(
249249
copy_eig_trunc, A, truncalg NoTangent();
250-
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
250+
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
251251
atol = atol, rtol = rtol
252252
)
253253
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
@@ -262,7 +262,7 @@ end
262262
ΔVtrunc = ΔV[:, ind]
263263
test_rrule(
264264
copy_eig_trunc, A, truncalg NoTangent();
265-
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
265+
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
266266
atol = atol, rtol = rtol
267267
)
268268
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
@@ -328,7 +328,7 @@ end
328328
ΔVtrunc = ΔV[:, ind]
329329
test_rrule(
330330
copy_eigh_trunc, A, truncalg NoTangent();
331-
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
331+
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
332332
atol = atol, rtol = rtol
333333
)
334334
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
@@ -343,7 +343,7 @@ end
343343
ΔVtrunc = ΔV[:, ind]
344344
test_rrule(
345345
copy_eigh_trunc, A, truncalg NoTangent();
346-
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
346+
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
347347
atol = atol, rtol = rtol
348348
)
349349
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
@@ -380,7 +380,7 @@ end
380380
test_rrule(
381381
config, eigh_trunc2, A;
382382
fkwargs = (; trunc = trunc),
383-
output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))),
383+
output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]),
384384
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
385385
)
386386
end
@@ -389,7 +389,7 @@ end
389389
test_rrule(
390390
config, eigh_trunc2, A;
391391
fkwargs = (; trunc = trunc),
392-
output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))),
392+
output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]),
393393
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
394394
)
395395
end
@@ -431,7 +431,7 @@ end
431431
ΔVᴴtrunc = ΔVᴴ[ind, :]
432432
test_rrule(
433433
copy_svd_trunc, A, truncalg NoTangent();
434-
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
434+
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]),
435435
atol = atol, rtol = rtol
436436
)
437437
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
@@ -448,7 +448,7 @@ end
448448
ΔVᴴtrunc = ΔVᴴ[ind, :]
449449
test_rrule(
450450
copy_svd_trunc, A, truncalg NoTangent();
451-
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
451+
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]),
452452
atol = atol, rtol = rtol
453453
)
454454
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
@@ -477,7 +477,7 @@ end
477477
test_rrule(
478478
config, svd_trunc, A;
479479
fkwargs = (; trunc = trunc),
480-
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
480+
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]),
481481
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
482482
)
483483
end
@@ -486,7 +486,7 @@ end
486486
test_rrule(
487487
config, svd_trunc, A;
488488
fkwargs = (; trunc = trunc),
489-
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
489+
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]),
490490
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
491491
)
492492
end

test/eig.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,21 @@ end
4848
D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
4949
@test length(diagview(D1)) == r
5050
@test A * V1 V1 * D1
51-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
51+
@test norm(ϵ1) norm(view(D₀, (r + 1):m)) atol = atol
5252

5353
s = 1 + sqrt(eps(real(T)))
5454
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
5555
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
5656
@test length(diagview(D2)) == r
5757
@test A * V2 V2 * D2
58-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
58+
@test norm(ϵ2) norm(view(D₀, (r + 1):m)) atol = atol
5959

6060
s = 1 - sqrt(eps(real(T)))
6161
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
6262
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
6363
@test length(diagview(D3)) == r
6464
@test A * V3 V3 * D3
65-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
65+
@test norm(ϵ3) norm(view(D₀, (r + 1):m)) atol = atol
6666

6767
# trunctol keeps order, truncrank might not
6868
# test for same subspace
@@ -83,13 +83,13 @@ end
8383
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
8484
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
8585
@test diagview(D2) diagview(D)[1:2]
86-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
86+
@test norm(ϵ2) norm(diagview(D)[3:4]) atol = atol
8787
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))
8888

8989
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
9090
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
9191
@test diagview(D3) diagview(D)[1:2]
92-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
92+
@test norm(ϵ3) norm(diagview(D)[3:4]) atol = atol
9393
end
9494

9595
@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
@@ -112,5 +112,5 @@ end
112112
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
113113
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
114114
@test diagview(D2) diagview(A2)[1:2]
115-
@test ϵ2 norm(diagview(A2)[3:4]) atol = atol
115+
@test norm(ϵ2) norm(diagview(A2)[3:4]) atol = atol
116116
end

test/eigh.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,21 @@ end
5757
@test isisometric(V1)
5858
@test A * V1 V1 * D1
5959
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
60-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
60+
@test norm(ϵ1) norm(view(D₀, (r + 1):m)) atol = atol
6161

6262
trunc = trunctol(; atol = s * D₀[r + 1])
6363
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
6464
@test length(diagview(D2)) == r
6565
@test isisometric(V2)
6666
@test A * V2 V2 * D2
67-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
67+
@test norm(ϵ2) norm(view(D₀, (r + 1):m)) atol = atol
6868

6969
s = 1 - sqrt(eps(real(T)))
7070
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
7171
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
7272
@test length(diagview(D3)) == r
7373
@test A * V3 V3 * D3
74-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
74+
@test norm(ϵ3) norm(view(D₀, (r + 1):m)) atol = atol
7575

7676
# test for same subspace
7777
@test V1 * (V1' * V2) V2
@@ -93,12 +93,12 @@ end
9393
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
9494
@test diagview(D2) diagview(D)[1:2]
9595
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
96-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
96+
@test norm(ϵ2) norm(diagview(D)[3:4]) atol = atol
9797

9898
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2))
9999
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
100100
@test diagview(D3) diagview(D)[1:2]
101-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
101+
@test norm(ϵ3) norm(diagview(D)[3:4]) atol = atol
102102
end
103103

104104
@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
@@ -122,5 +122,5 @@ end
122122
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
123123
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
124124
@test diagview(D2) diagview(A2)[1:2]
125-
@test ϵ2 norm(diagview(A2)[3:4]) atol = atol
125+
@test norm(ϵ2) norm(diagview(A2)[3:4]) atol = atol
126126
end

test/genericlinearalgebra/eigh.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,21 @@ end
4949
@test isisometric(V1)
5050
@test A * V1 V1 * D1
5151
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
52-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
52+
@test norm(ϵ1) norm(view(D₀, (r + 1):m)) atol = atol
5353

5454
trunc = trunctol(; atol = s * D₀[r + 1])
5555
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
5656
@test length(diagview(D2)) == r
5757
@test isisometric(V2)
5858
@test A * V2 V2 * D2
59-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
59+
@test norm(ϵ2) norm(view(D₀, (r + 1):m)) atol = atol
6060

6161
s = 1 - sqrt(eps(real(T)))
6262
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
6363
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
6464
@test length(diagview(D3)) == r
6565
@test A * V3 V3 * D3
66-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
66+
@test norm(ϵ3) norm(view(D₀, (r + 1):m)) atol = atol
6767

6868
# test for same subspace
6969
@test V1 * (V1' * V2) V2
@@ -84,10 +84,10 @@ end
8484
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
8585
@test diagview(D2) diagview(D)[1:2]
8686
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
87-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
87+
@test norm(ϵ2) norm(diagview(D)[3:4]) atol = atol
8888

8989
alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2))
9090
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
9191
@test diagview(D3) diagview(D)[1:2]
92-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
92+
@test norm(ϵ3) norm(diagview(D)[3:4]) atol = atol
9393
end

0 commit comments

Comments
 (0)