
feat: Add STCH aggregator#529

Open
rkhosrowshahi wants to merge 3 commits into SimplexLab:main from rkhosrowshahi:add-stch-aggregator

Conversation

@rkhosrowshahi
Contributor

Summary

Key Features

  • mu parameter: Controls smoothness of aggregation (smaller = harder max focusing on worst task, larger = uniform averaging)
  • warmup_steps parameter (optional): Number of steps to accumulate gradient norms for nadir vector computation. If None, no warmup is performed.
  • reset() method: Clears internal state (step counter, nadir vector) between experiments

Implementation Notes

The original STCH algorithm operates on loss values with a warmup phase. This implementation adapts it for gradient-based aggregation:

  • Uses gradient norms (from Gramian diagonal) as proxies for task performance
  • Follows the stateful pattern similar to NashMTL for the warmup mechanism

Test Plan

  • Unit tests for expected structure and permutation invariance
  • Tests for parameter validation (invalid mu, warmup_steps)
  • Tests for warmup behavior (uniform weights during warmup, nadir computation after)
  • Tests for reset() functionality
  • Tests for edge cases (small/large mu values)

rkhosrowshahi and others added 2 commits January 25, 2026 16:27
Implement the Smooth Tchebycheff (STCH) scalarization algorithm from
"Smooth Tchebycheff Scalarization for Multi-Objective Optimization"
(https://arxiv.org/abs/2402.19078).

The aggregator uses log-sum-exp (smooth maximum) to compute weights that
focus more on poorly performing tasks. Key features:

- mu parameter controls smoothness (smaller = harder max, larger = uniform)
- Optional warmup_steps for computing nadir vector from gradient norms
- reset() method for clearing state between experiments
- Both STCH (Aggregator) and STCHWeighting (Weighting) classes provided
@codecov

codecov bot commented Jan 25, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines              Coverage Δ
src/torchjd/aggregation/__init__.py   100.00% <100.00%> (ø)
src/torchjd/aggregation/_stch.py      100.00% <100.00%> (ø)

Add an additional forward() call in test_warmup_step_counter to cover
line 207 in _stch.py (the else branch that increments step during
steady-state operation after warmup completes).
@rkhosrowshahi
Contributor Author

Fixed the uncovered line (line 207 in _stch.py) by adding an additional forward() call in test_warmup_step_counter.

The uncovered line was the else branch that increments the step counter during steady-state operation (after warmup completes and nadir_vector is already set):

else:
    self.step += 1

The test now covers:

  • Calls 1-3: During warmup (step < warmup_steps)
  • Call 4: First step after warmup (nadir_vector gets computed)
  • Call 5 (new): Steady-state operation (nadir_vector already exists) ✓
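For readers following along, the warmup/steady-state logic being tested can be sketched like this. It is a toy reimplementation with assumed names (`WarmupNadir`, `forward`), not the actual `_stch.py` code:

```python
import torch

class WarmupNadir:
    # Illustrative sketch: accumulate gradient norms for `warmup_steps` calls,
    # then fix the nadir vector once and keep counting steps afterwards.
    def __init__(self, warmup_steps: int = 3, mu: float = 1.0):
        self.warmup_steps = warmup_steps
        self.mu = mu
        self.step = 0
        self.accum = None
        self.nadir_vector = None

    def forward(self, norms: torch.Tensor) -> torch.Tensor:
        if self.warmup_steps is not None and self.step < self.warmup_steps:
            self.accum = norms if self.accum is None else self.accum + norms
            self.step += 1
            # Uniform weights during warmup.
            return torch.full_like(norms, 1.0 / norms.numel())
        if self.nadir_vector is None:
            # First step after warmup: compute the nadir vector once.
            self.nadir_vector = self.accum / self.warmup_steps
            self.step += 1
        else:
            # Steady state (the branch the extra forward() call covers).
            self.step += 1
        return torch.softmax(norms / self.nadir_vector / self.mu, dim=0)
```

Calls 1-3 hit the warmup branch, call 4 computes the nadir vector, and call 5 exercises the steady-state `else` branch described above.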

@ValerianRey
Contributor

Thanks for this PR @rkhosrowshahi! I'll review it tomorrow most likely. Also, I just changed the actions rules so that checks run automatically now (without maintainer approval every time), so you can work without interruption.

@PierreQuinton
Contributor

Thanks a lot for the PR.

I will add another review later on, but first I am curious about something. From what I read in the paper, STCH is a scalarization method (i.e. we compute some combination of the coordinates of the objective and differentiate that). The main advantage of those methods is that they don't require computing the full Jacobian (or Gramian, I guess). But here, this implementation uses the norms of the gradients, which itself requires at least the Gramian. I couldn't find a description of the latter in the paper or in the repository; could you provide me with a direct link to their code?

Another note: I think we need to add a line to the table in the README.md (there might be other places where we have to add a reference to this).

Comment on lines +87 to +94
if mu <= 0.0:
    raise ValueError(f"Parameter `mu` should be a positive float. Found `mu = {mu}`.")

if warmup_steps is not None and warmup_steps < 1:
    raise ValueError(
        f"Parameter `warmup_steps` should be a positive integer or None. "
        f"Found `warmup_steps = {warmup_steps}`."
    )
Contributor

I would move those checks to the Weighting, as they need to be checked there as well.

@ValerianRey
Contributor

From what I read in the paper, STCH is a scalarization method (i.e. we compute some combination of the coordinates of the objective and differentiate that). The main advantage of those methods is that they don't require computing the full Jacobian (or Gramian, I guess). But here, this implementation uses the norms of the gradients, which itself requires at least the Gramian.

I completely agree with this. That being said, should we add a package to TorchJD for scalarization methods? It seems that the MTL community would benefit from that, even though it's not the main scope of the library. I think it can be quite community-driven, and it would give a more modern alternative to LibMTL (which I think is a bit dying these days).

@PierreQuinton
Contributor

Sorry, this is mostly me thinking out loud; it doesn't contain much information.

That would be nice indeed: it means that most MTL methods could be handled by us, and we would take care of maintaining them. Still, it feels almost silly to have something as simple as:

losses = ...
something = STCH(losses)
something.backward()

where STCH basically just computes some scalarization of the losses.

I guess the main advantage of having this in this library is that we can then handle all the parametrizations. However, we will not be able to have one interface for all possible MTL methods (or maybe we could, by taking the losses and the parameters?).

@ValerianRey
Contributor

Scalarization methods are quite the opposite of what we're doing with JD, but it would be nice to have access to them to be able to compare them, I think.

Something like:

from torchjd.scalarization import STCH

...
losses = ...
loss = STCH(losses)
loss.backward()

doesn't seem silly, just simple, and that's what we want.
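A runnable sketch of what such a scalarization helper could look like (the name `stch_scalarize` and the interface are assumptions, not torchjd's actual API): `mu * logsumexp(losses / mu)` is a smooth upper bound on `max(losses)`, and its gradient with respect to the losses is `softmax(losses / mu)`.

```python
import torch

def stch_scalarize(losses: torch.Tensor, mu: float = 0.1) -> torch.Tensor:
    # Smooth Tchebycheff scalarization: a smooth maximum of the losses.
    return mu * torch.logsumexp(losses / mu, dim=0)

losses = torch.tensor([0.5, 2.0, 1.0], requires_grad=True)
loss = stch_scalarize(losses)
loss.backward()  # losses.grad == softmax(losses / mu), focused on the worst task
```

Only a single backward pass through the scalar `loss` is needed, which is exactly why no Jacobian or Gramian computation is involved.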

However, we will not be able to have one interface for all possible MTL methods

That's a solid argument. I think if we were able to support most methods without many changes to the structure of torchjd (e.g. just adding a scalarization package), we could really add them, but if we can't, it's a bit of a lost hope IMO. For methods that need access to the losses (or really anything else), we can provide these values at the initialization of the aggregator, so at least those cases can be covered. Stateful methods could also be added if we just add a warning saying that these aren't strictly aggregators in terms of the mathematical definition of an aggregator.

Before taking a decision, we could look at the methods from LibMTL that we didn't include in TorchJD, and see if it would be possible to include them:

@rkhosrowshahi
Contributor Author

Thanks to both of you for thinking about this PR. I didn't expect it to spark such a long discussion, but I hope I didn't cross any line :)

Yes, scalarization methods differ from the others in that they use loss values rather than the Jacobian matrix when computing gradients. When I was implementing from LibMTL's STCH, I hit a major obstacle: TorchJD passes the Gramian matrix to the aggregator. So I had to use the gradient norms as the basis for computing the weights in the STCH weighting algorithm. I am now not sure whether `grad_norm / nadir_vector` is equal to `loss / nadir_vector` or not.

I really like the idea of a separate package for STCH, as it would be easier for the user, as you have shown. It would avoid using the Jacobian matrix and be more efficient.

I know this is not the place to ask, but since it is being discussed: is TorchJD willing to cover all sorts of MOO algorithms (stateless, stateful, single Pareto solution, finite/infinite sets of Pareto solutions) in the future?

@PierreQuinton
Contributor

PierreQuinton commented Jan 26, 2026

I am now not sure if `grad_norm / nadir_vector` is equal to `loss / nadir_vector` or not.

I'm pretty sure this cannot hold.
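A quick illustrative check supports this (a toy example, not from the PR): gradient norms are not proportional to loss values in general, so normalizing either one by a nadir vector yields different weightings.

```python
import torch

# With f1(x) = x**2 and f2(x) = x**4 at x = 2, the losses are (4, 16) but the
# gradient magnitudes are (4, 32): the two vectors are not proportional, so
# they induce different STCH weights.
x = torch.tensor(2.0, requires_grad=True)
l1, l2 = x**2, x**4
g1 = torch.autograd.grad(l1, x, retain_graph=True)[0].abs()
g2 = torch.autograd.grad(l2, x)[0].abs()
loss_ratio = (l2 / l1).item()  # 16 / 4 = 4
grad_ratio = (g2 / g1).item()  # 32 / 4 = 8
```

Since the ratios differ, `grad_norm / nadir_vector` cannot equal `loss / nadir_vector` for any fixed nadir vector.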

Is TorchJD willing to cover all sorts of MOO algorithms, including stateless, stateful, single Pareto solution, finite/infinite set of Pareto solutions, in the future?

It is a matter of the responsibility we are able to take. For now, we are responsible for implementation and optimization of methods based on the Jacobian and/or its Gramian (Jacobian descent). Sadly, that is already a bit too much for 2 people (half time), but if there is some traction from the community, as well as extra people willing to help, then most definitely we could try.

@ValerianRey
Contributor

Is TorchJD willing to cover all sorts of MOO algorithms, including stateless, stateful, single Pareto solution, finite/infinite set of Pareto solutions, in the future?

It is a matter of the responsibility we are able to take. For now, we are responsible for implementation and optimization of methods based on the Jacobian and/or its Gramian (Jacobian descent). Sadly, that is already a bit too much for 2 people (half time), but if there is some traction from the community, as well as extra people willing to help, then most definitely we could try.

I mostly agree with that, but I also think that some methods would be quite easy to incorporate, and these methods could come from the community (I won't have time to implement any of them, but I could spend a bit of time working on a structure and reviewing some PRs for these methods). I'm thinking in particular of stateful methods (we spent a lot of time discussing this in the past, but I think all we need in the end is just a reset method plus a disclaimer stating that these aren't true mathematical aggregators), and maybe scalarization methods. We're probably never going to include methods that find multiple (finitely or infinitely many) solutions, though, as this seems like an entirely different approach compared to what we're doing.

So to conclude, I'm open to having TorchJD cover more and more MOO algorithms, but IMO it should never become an exhaustive collection of all methods from the literature. Let's just grab the low-hanging fruit and keep improving what we already have.

I would love to know more about your vision of TorchJD @rkhosrowshahi. Do you see it as a potential tool to compare MOO methods? If so, what's the problem with using LibMTL?

I don't have a definitive answer about whether we should add a scalarization package to TorchJD. Maybe we could discuss that in a voice call on discord all three of us @PierreQuinton @rkhosrowshahi ?

@ValerianRey added the labels "cc: feat" (Conventional commit type for new features) and "package: aggregation" on Feb 1, 2026