@lkdvos
Member

@lkdvos lkdvos commented Dec 2, 2025

This PR adds ChainRules rules for the different f_vals functions.
To keep these implementations sufficiently generic, I added f_vals_pullback functions that can be overloaded, but usually don't need to be, since they simply fall back to f_pullback with the correct zero tangents filled in.
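
To illustrate the fallback pattern described above, here is a minimal sketch for the Hermitian eigenvalue case. The names `full_pullback_sketch` and `vals_pullback_sketch` and their signatures are hypothetical and are not the functions added in this PR; the sketch only handles the eigenvalue cotangent and assumes non-degenerate eigenvalues.

```julia
using LinearAlgebra

# Hypothetical full pullback for A = V * Diagonal(D) * V' (Hermitian case).
# Only the eigenvalue cotangent ΔD is handled; ΔV must be `nothing` here.
function full_pullback_sketch(A, (D, V), (ΔD, ΔV))
    ΔV === nothing || error("eigenvector cotangents are not handled in this sketch")
    return V * Diagonal(ΔD) * V'
end

# Hypothetical values-only pullback: fill in a zero tangent for V and delegate.
vals_pullback_sketch(A, DV, Δvals) = full_pullback_sketch(A, DV, (Δvals, nothing))

A = Symmetric([2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 5.0])
F = eigen(A)
c = [1.0, -2.0, 0.5]                      # cotangent of the eigenvalues
ΔA = vals_pullback_sketch(A, (F.values, F.vectors), c)

# Central finite-difference check along a symmetric direction H.
H = Symmetric([0.3 -0.1 0.2; -0.1 0.4 0.0; 0.2 0.0 -0.5])
f(M) = dot(c, eigvals(Symmetric(M)))
ε = 1e-6
fd = (f(A + ε * H) - f(A - ε * H)) / (2ε)
check = isapprox(fd, dot(ΔA, H); atol = 1e-6)
```

Note that the values-only function still receives the full decomposition `(D, V)`, which is why it acts as a thin wrapper rather than a separate algorithm.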

A small hiccup here is that in order to do this, I effectively require the inverse operation of diagview, which I introduced here as diagonal. I didn't want to simply use Diagonal, since that has the restriction that it has to return a Diagonal type, and MatrixAlgebraKit feels like the correct place to introduce such a function.
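
As a rough illustration of the relationship between the two functions, the sketch below uses hypothetical stand-ins `my_diagview` and `my_diagonal` built only on LinearAlgebra. They are not the MatrixAlgebraKit implementations, and the `Diagonal` fallback for plain vectors is an assumption.

```julia
using LinearAlgebra

# Hypothetical stand-in for diagview: a view of the main diagonal of a matrix.
my_diagview(A::AbstractMatrix) = view(A, diagind(A))

# Hypothetical stand-in for the new function: build a matrix with v on its diagonal.
# For a plain vector, Diagonal is a natural choice; a custom array type could add
# its own method and return something that is not a Diagonal.
my_diagonal(v::AbstractVector) = Diagonal(v)

v = [1.0, 2.0, 3.0]
D = my_diagonal(v)
A = Matrix(D)                             # dense copy, to exercise the view

roundtrip_vec = my_diagview(A) == v       # diagview undoes diagonal
roundtrip_mat = my_diagonal(my_diagview(A)) == D
```

Having a separate function rather than calling `Diagonal` directly leaves room for other array types to choose their own return type.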

Finally, I also updated the Mooncake rules to use the updated pullback functions. As I'm still getting the hang of Mooncake, it would be nice if someone could double-check that I did that correctly.

This PR fixes #88.

@lkdvos lkdvos requested review from Jutho and kshyatt December 2, 2025 15:26
@codecov

codecov bot commented Dec 2, 2025

Codecov Report

❌ Patch coverage is 91.89189% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ext/MatrixAlgebraKitChainRulesCoreExt.jl | 83.33% | 3 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| ...gebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl | 98.97% <100.00%> (+<0.01%) ⬆️ |
| src/MatrixAlgebraKit.jl | 100.00% <ø> (ø) |
| src/common/view.jl | 100.00% <100.00%> (ø) |
| src/pullbacks/eig.jl | 96.10% <100.00%> (+0.15%) ⬆️ |
| src/pullbacks/eigh.jl | 91.42% <100.00%> (+0.38%) ⬆️ |
| src/pullbacks/svd.jl | 96.36% <100.00%> (+0.10%) ⬆️ |
| ext/MatrixAlgebraKitChainRulesCoreExt.jl | 81.42% <83.33%> (+0.28%) ⬆️ |

@kshyatt
Member

kshyatt commented Dec 2, 2025

LGTM, although why not add these special pullbacks (which are not needed for Enzyme or Mooncake) to the ChainRules extension?

@lkdvos
Member Author

lkdvos commented Dec 2, 2025

I think we might have to specialize the TensorMap implementations, and then it is easier and slightly more consistent to have these pullbacks as functions that are available in the main package.
Additionally, I think it looks slightly cleaner to move the logic of filling in the zero tangents and mapping from a diagonal vector to a matrix out of the Mooncake implementation (obviously these are very small functions, so manually inlining them would also work), so maybe it is just a code organization thing?

@lkdvos lkdvos enabled auto-merge (squash) December 2, 2025 18:09
@lkdvos lkdvos merged commit cea4178 into main Dec 2, 2025
10 checks passed
@lkdvos lkdvos deleted the ld-valsrrules branch December 2, 2025 19:27
@Jutho
Member

Jutho commented Dec 2, 2025

Ok, this went fast. I still have to study the PR in a bit more detail. Was it worth adding the f_vals_pullback in the generic pullbacks, rather than just doing the lowering ΔD -> ΔDV = (ΔD, nothing) (and similar for svd) in the ChainRules/Mooncake bindings?

With f_trunc, I followed the strategy that either you go via f_pullback in the bindings, which requires the output of f_full, or there is a dedicated f_trunc_pullback, which only requires the output of the primal calculation of f_trunc. So by analogy, I would expect f_vals_pullback to be a function that only depends on A, the output of f_vals (which is just vals), and Δvals. However, I don't actually think such a function can exist, i.e. you need the output of f_full in order to compute the pullback.

@Jutho
Member

Jutho commented Dec 2, 2025

Oh I also only now see that you did actually discuss this above 😄

@lkdvos
Member Author

lkdvos commented Dec 3, 2025

Apologies for the fast merge, I thought this was a pretty small change somehow.
Given that nothing is released, we can definitely still make changes here if you like, although I do think that, specifically for the comments above, the main point is to leave a hook that can be overloaded without having to overload the rrules of either ChainRules or Mooncake.

@Jutho
Member

Jutho commented Dec 3, 2025

No problem. Yes, I agree such a hook is useful, and since the pullback cannot be computed without knowing V (or U and Vᴴ in the SVD case), I think the current interface is ok.
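
For reference, the dependence on V (or U and Vᴴ) can be seen from standard first-order perturbation theory. The sketch below assumes non-degenerate eigenvalues and singular values, real cotangents Δλ and Δs, and the inner product ⟨X, Y⟩ = Re tr(Xᴴ Y); the notation is ours, not taken from the PR.

```math
\begin{aligned}
&\text{Hermitian eigendecomposition } A = V \Lambda V^{\mathsf H}: &
d\lambda_i &= v_i^{\mathsf H}\, dA\, v_i, &
\Delta A &= V \operatorname{diag}(\Delta\lambda)\, V^{\mathsf H}, \\
&\text{SVD } A = U S V^{\mathsf H}: &
ds_i &= \operatorname{Re}\!\left(u_i^{\mathsf H}\, dA\, v_i\right), &
\Delta A &= U \operatorname{diag}(\Delta s)\, V^{\mathsf H}.
\end{aligned}
```

In both cases the cotangent ΔA is assembled from the eigenvectors or singular vectors, so the values alone are not enough to evaluate it.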

@lkdvos lkdvos mentioned this pull request Dec 28, 2025
Development

Successfully merging this pull request may close these issues.

[Autodiff] svd_vals is not differentiable

4 participants