
Add scalarization package #666

@ValerianRey

Description

To be able to compare aggregators to simple baselines, it would be great to have a torchjd.scalarization package.
This package would provide all sorts of Scalarizers, to combine multiple losses into a single scalar loss.

Proposed usage example:

```python
from torchjd.scalarization import Mean

...
scalarizer = Mean()
losses = criterion(output, target)
loss = scalarizer(losses)
loss.backward()
...
```

To implement this, we would need:

  • A new public scalarization package in torchjd
  • An abstract base class Scalarizer
  • To start, a few trivial scalarizers, e.g. Mean (aka Equal Weights, or EW), Sum, Constant (aka Linear Scalarization, or LS), and Random (aka Random Loss Weighting, or RLW).
  • In future pull requests, we could add more interesting scalarization methods one by one. I'm thinking of STCH, and maybe FAMO, but I'll make a dedicated issue to track those.
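
Roughly, I imagine something like this (just a sketch; the exact base class and signatures are not decided, and none of this code exists yet):

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class Scalarizer(ABC):
    """Combines a tensor of losses into a single scalar loss."""

    @abstractmethod
    def __call__(self, losses: Tensor) -> Tensor:
        """Return a scalar (0-dim) tensor combining the given losses."""


class Mean(Scalarizer):
    """Equal Weights (EW): average all losses."""

    def __call__(self, losses: Tensor) -> Tensor:
        return losses.mean()


class Sum(Scalarizer):
    """Sum all losses."""

    def __call__(self, losses: Tensor) -> Tensor:
        return losses.sum()


class Random(Scalarizer):
    """Random Loss Weighting (RLW): draw fresh random weights at each call."""

    def __call__(self, losses: Tensor) -> Tensor:
        weights = torch.softmax(torch.randn_like(losses), dim=-1)
        return (weights * losses).sum()
```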

A few questions:

  • Is it fine to have some names in common between scalarizers and aggregators (e.g. torchjd.scalarization.Mean and torchjd.aggregation.Mean)? Maybe we want to name them MeanScalarizer, SumScalarizer, ConstantScalarizer, STCHScalarizer, etc. But at the same time, the package path should be responsible for disambiguating them, and people can always import with from torchjd.scalarization import Mean as MeanScalarizer. So I'm not sure.
  • Should Scalarizer inherit from nn.Module (like Aggregator and Weighting do)? It makes it much easier to add hooks, but I'm not sure hooks will even be needed for scalarizers, and nn.Module may lead to some typing issues. Another advantage of nn.Module is that it can have trainable parameters, which I think would be necessary for GradNorm.
  • There is a slightly more general concept than scalarization, which is to group some losses. For example, you could start with 128 losses, group them into 32 groups of 4 losses, and average each group, ending up with 32 losses. We could thus have a Combiner that takes a loss tensor and returns another loss tensor, not necessarily a scalar. In my opinion, we should develop Scalarizer without thinking too much about this; maybe one day Scalarizer will become a special case of Combiner.
  • Should the input to the __call__ method of a Scalarizer necessarily be a vector (a tensor with ndim=1), or could it be a tensor of any shape (scalar, vector, matrix, etc.)? I think it would be better to accept any shape, to be consistent with the interface of autojac.backward, which works on tensors of any shape. Also, some methods may make use of the shape of the loss tensor.
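
To illustrate the last two points with plain torch ops (just a hypothetical sketch; neither Combiner nor a shape-agnostic Scalarizer exists in torchjd):

```python
import torch

losses = torch.randn(128)  # e.g. one loss per training sample

# Grouping (the Combiner idea): 128 losses -> 32 groups of 4 -> 32 averaged losses.
grouped = losses.reshape(32, 4).mean(dim=1)

# Full scalarization would then be the special case that reduces everything.
# torch's mean() already works regardless of the input's shape, so the same
# Scalarizer could accept a scalar, a vector, or a matrix of losses.
scalar_from_vector = losses.mean()
scalar_from_matrix = losses.reshape(8, 16).mean()
```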
