feat(scalarization): Add scalarization package #701
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Very nice PR, thanks!
There are a few things we need to discuss with @ValerianRey, but it's looking very good.
We are missing documentation (rst files in docs). We should also put a docstring in `__init__.py` that explains the package, similarly to aggregation.
Also @ValerianRey, should we define a code owner for the package, or leave it as the default (maintainers) and make @ppraneth code owner after he has contributed a few more PRs?
    """
    Abstract base class for all scalarizers. Reduces a tensor of losses of any shape into a single
    scalar loss that can be passed to :meth:`~torch.Tensor.backward`.
    """
@ValerianRey I think I would abstract away from losses and differentiation, and instead describe this as reducing input tensors to scalars. What do you think?
    scalar_input: Tensor = tensor_(7.0)
    vector_input: Tensor = randn_(5)
    matrix_input: Tensor = randn_(3, 4)
    tensor_3d_input: Tensor = randn_(2, 3, 4)


    typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
    all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
    - scalar_input: Tensor = tensor_(7.0)
    - vector_input: Tensor = randn_(5)
    - matrix_input: Tensor = randn_(3, 4)
    - tensor_3d_input: Tensor = randn_(2, 3, 4)
    - typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
    - all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
    + scalar_inputs: list[Tensor] = [randn_([]) for _ in range(3)]
    + vector_inputs: list[Tensor] = [randn_([5]) for _ in range(3)]
    + matrix_inputs: list[Tensor] = [randn_([3, 4]) for _ in range(3)]
    + tensor_3d_inputs: list[Tensor] = [randn_([2, 3, 4]) for _ in range(3)]
    + typical_inputs: list[Tensor] = vector_inputs + matrix_inputs + tensor_3d_inputs
    + all_inputs: list[Tensor] = scalar_inputs + typical_inputs
Maybe we should even have many shapes instead. What do you think, @ValerianRey?
I would rename typical_inputs to non_scalar_inputs.
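A minimal sketch of how the suggested input lists could be used in a shape check, with the rename applied. It assumes the project's `randn_` helper behaves like `torch.randn`, so plain `torch.randn` is used here; `check_reduces_to_scalar` is a hypothetical helper and not part of the PR.

```python
from collections.abc import Callable

import torch
from torch import Tensor

# Assumption: the project's randn_ helper wraps torch.randn, so torch.randn is a stand-in.
scalar_inputs: list[Tensor] = [torch.randn([]) for _ in range(3)]
vector_inputs: list[Tensor] = [torch.randn([5]) for _ in range(3)]
matrix_inputs: list[Tensor] = [torch.randn([3, 4]) for _ in range(3)]
tensor_3d_inputs: list[Tensor] = [torch.randn([2, 3, 4]) for _ in range(3)]

non_scalar_inputs: list[Tensor] = vector_inputs + matrix_inputs + tensor_3d_inputs
all_inputs: list[Tensor] = scalar_inputs + non_scalar_inputs


def check_reduces_to_scalar(reduce: Callable[[Tensor], Tensor], inputs: list[Tensor]) -> None:
    """Hypothetical helper: assert that `reduce` maps every input to a 0-dim tensor."""
    for x in inputs:
        out = reduce(x)
        assert out.shape == torch.Size([]), f"expected 0-dim output, got {tuple(out.shape)}"


# torch.mean and torch.sum stand in for the Mean and Sum scalarizers.
check_reduces_to_scalar(torch.mean, all_inputs)
check_reduces_to_scalar(torch.sum, all_inputs)
```

Drawing several random tensors per shape, as suggested, makes a check like this less dependent on one particular draw.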
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Closes #666.
Summary
Adds a new `torchjd.scalarization` package providing simple baselines against which aggregators can be compared. Includes:
- `Scalarizer`: abstract base class, inherits from `nn.Module`.
- `Mean`: returns `losses.mean()`.
- `Sum`: returns `losses.sum()`.
- `Constant(weights)`: returns `(weights * losses).sum()`. Validates `weights.shape == losses.shape` at call time.
- `Random`: combines losses with positive random weights summing to 1 (RLW, Algorithm 2 of arXiv 2111.10603).

All scalarizers accept loss tensors of any shape (including 0-dim) and return a 0-dim scalar.
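The behaviour described in this summary can be sketched roughly as follows. This is an illustration written from the description, not the code in the PR: the error message is made up, and `Random` is assumed to draw its weights as a softmax over standard normal samples, which may differ from the actual implementation.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class Scalarizer(nn.Module, ABC):
    """Reduces a tensor of losses of any shape to a 0-dim tensor."""

    @abstractmethod
    def forward(self, losses: Tensor) -> Tensor: ...


class Mean(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        return losses.mean()


class Sum(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        return losses.sum()


class Constant(Scalarizer):
    def __init__(self, weights: Tensor) -> None:
        super().__init__()
        self.weights = weights

    def forward(self, losses: Tensor) -> Tensor:
        # Shapes are compared at call time, as stated in the summary.
        if self.weights.shape != losses.shape:
            raise ValueError(
                f"weights shape {tuple(self.weights.shape)} != losses shape {tuple(losses.shape)}"
            )
        return (self.weights * losses).sum()


class Random(Scalarizer):
    def forward(self, losses: Tensor) -> Tensor:
        # Assumption: softmax of standard normal samples, which yields positive
        # weights that sum to 1.
        logits = torch.randn(losses.numel(), dtype=losses.dtype, device=losses.device)
        weights = torch.softmax(logits, dim=0).reshape(losses.shape)
        return (weights * losses).sum()


losses = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
loss = Mean()(losses)  # 0-dim tensor with value 2.5
loss.backward()        # each entry of losses.grad is 0.25
```

Because every scalarizer returns a 0-dim tensor, the result can be passed directly to `backward()`, whatever the shape of `losses`.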
Design decisions (confirmed with maintainers)
- No `Scalarizer` suffix.
- `Scalarizer` inherits from `nn.Module` for consistency with `Aggregator` and `Weighting`, and to leave room for trainable scalarizers later.
- No `Combiner` generalization for now.
- `Constant` uses option (a): `weights.shape` must equal `losses.shape`.

The `Stateful` mixin stays in `torchjd.aggregation._mixins` for now. When the first stateful scalarizer lands, it can be moved to `torchjd._mixins` so both packages share it. A comment in `scalarization/__init__.py` records this.

Test plan
- `uv run pytest tests/unit/scalarization -W error -v` passes.
- `uv run pytest tests/unit/scalarization --cov=src/torchjd/scalarization` shows full coverage.
- `uv run pytest tests/unit -W error` passes (3019 passed, 66 skipped, 33 xfailed).
- `PYTEST_TORCH_DTYPE=float64 uv run pytest tests/unit/scalarization -W error` passes.
- `uv run ruff format --check` and `uv run ruff check` pass.
- `uv run ty check` passes.
- `uv run pre-commit run --all-files` passes.