Description
When running variational optimization with a cost function that uses MatrixAlgebraKit.svd_trunc in an intermediate step and whose value is independent of the normalization of the input, the current choice of tolerances in the svd_trunc pullback can give rise to very inaccurate gradients, depending on how the input is initially normalized. This came up for example in QuantumKitHub/PEPSKit.jl#276, but the problem persists with the new implementations in MatrixAlgebraKit.jl.
An example would be
using Random
using LinearAlgebra
using MatrixAlgebraKit
using OptimKit
using Zygote
Random.seed!(123)
# some random matrix
O = randn(ComplexF64, 10, 10)
# some cost function which uses svd_trunc
function fg(t)
    e, g = Zygote.withgradient(t) do x
        _, S, _ = svd_trunc(x; trunc = truncrank(10))
        return real(tr(O * (S / norm(S))))
    end
    return e, only(g)
end
# prepare a random input
t0 = randn(ComplexF64, 20, 20)
# but artificially make singular values decay exponentially
U, S, V = svd_full(t0)
S´ = S .* diagm(logrange(1.0e0, 1.0e-12, size(S, 1)))
t´ = U * S´ * V
Clearly, the value of the cost function is independent of the initial norm of t´. Checking the gradient without any prior normalization of t´ gives a good result,
@show fg(t´)[1]
_, _, dfs1, dfs2 = OptimKit.optimtest(
    fg, t´;
    alpha = LinRange(-1.0e-5, 1.0e-5, 2),
)
@show dfs1
@show dfs2
which prints
(fg(t´))[1] = -0.621052432092075
dfs1 = [0.06446034588325489]
dfs2 = [0.06446034587777202]
Running the same for an input with unit norm already gives a significantly less accurate result,
rescaling_factor = norm(t´)
@show fg(t´ / rescaling_factor)[1]
_, _, dfs1, dfs2 = OptimKit.optimtest(
    fg, t´ / rescaling_factor;
    alpha = LinRange(-1.0e-5, 1.0e-5, 2),
)
@show dfs1
@show dfs2
which prints
(fg(t´ / rescaling_factor))[1] = -0.6210524320920751
dfs1 = [4.3408483775597695]
dfs2 = [4.909645149555295]
And finally, running the same for an input with a very small norm seems to break things completely,
rescaling_factor = 1.0e3 * norm(t´)
@show fg(t´ / rescaling_factor)[1]
_, _, dfs1, dfs2 = OptimKit.optimtest(
    fg, t´ / rescaling_factor;
    alpha = LinRange(-1.0e-5, 1.0e-5, 2),
)
@show dfs1
@show dfs2
which prints
(fg(t´ / rescaling_factor))[1] = -0.6210524320920752
dfs1 = [-2681.1597759100105]
dfs2 = [4.909645149555297e6]
I think the bottom line is that, with the default settings, the safe inverse in the SVD pullback uses a tolerance that does not take the norm of the input into account, and therefore simply discards essential information when the norm of the input tensor is very small. I know that I can fix this by manually setting the tolerance of the pullback to something that scales with the norm of the input, but this is quite annoying to do in practice.
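To make the mechanism concrete, here is a toy sketch of what I believe happens (safe_inv and atol below are made-up stand-ins, not the actual MatrixAlgebraKit internals): the pullback involves factors of the form 1 / (sᵢ² - sⱼ²) that are regularized with a safe inverse, and with a fixed absolute tolerance these factors get zeroed out entirely once the singular values are scaled down far enough,
# Toy stand-in for a regularized inverse with a fixed absolute tolerance;
# `safe_inv` and `atol` are hypothetical names, not MatrixAlgebraKit internals.
safe_inv(x, tol) = abs(x) < tol ? zero(x) : inv(x)
atol = 1.0e-12

# Two singular values of a small-norm input: their squared gap (~7.5e-13)
# falls below the absolute tolerance, so the pullback factor 1 / (s₁² - s₂²)
# is replaced by zero and the corresponding gradient information is lost.
s_small = [1.0e-6, 5.0e-7]
@show safe_inv(s_small[1]^2 - s_small[2]^2, atol)   # 0.0

# The same spectrum rescaled to order one: the gap is now far above the
# tolerance and the factor survives.
s_unit = 1.0e6 .* s_small
@show safe_inv(s_unit[1]^2 - s_unit[2]^2, atol)     # ≈ 1.333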
I would really just expect the example here to work out of the box. Is there any reason not to scale (some of) the tolerances in the pullback by the norm of the input of the primal computation?
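As a rough sketch of what I mean (again with made-up names, intended only to illustrate the scaling, not as a proposed implementation), measuring the tolerance against the scale of the primal input makes the regularization behave consistently under rescaling:
# Hypothetical relative variant: compare against a tolerance tied to the
# largest singular value of the primal input instead of a fixed constant.
safe_inv_rel(x, rtol, smax) = abs(x) < rtol * smax^2 ? zero(x) : inv(x)
rtol = 1.0e-12

s_small = [1.0e-6, 5.0e-7]
s_unit  = 1.0e6 .* s_small

# Both spectra now keep the factor (it scales as 1/scale², as it should),
# instead of it being zeroed out only for the small-norm input.
@show safe_inv_rel(s_small[1]^2 - s_small[2]^2, rtol, maximum(s_small))  # ≈ 1.333e12
@show safe_inv_rel(s_unit[1]^2 - s_unit[2]^2, rtol, maximum(s_unit))     # ≈ 1.333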