Skip to content

feat(scalarization): Add scalarization package#701

Open
ppraneth wants to merge 7 commits into
SimplexLab:mainfrom
ppraneth:scalarization
Open

feat(scalarization): Add scalarization package#701
ppraneth wants to merge 7 commits into
SimplexLab:mainfrom
ppraneth:scalarization

Conversation

@ppraneth
Copy link
Copy Markdown

Closes #666.

Summary

Adds a new torchjd.scalarization package providing simple baselines against which aggregators can be compared. Includes:

  • Scalarizer: abstract base class, inherits from nn.Module.
  • Mean: returns losses.mean().
  • Sum: returns losses.sum().
  • Constant(weights): returns (weights * losses).sum(). Validates weights.shape == losses.shape at call time.
  • Random: combines losses with positive random weights summing to 1 (RLW, Algorithm 2 of arXiv 2111.10603).

All scalarizers accept loss tensors of any shape (including 0-dim) and return a 0-dim scalar.

Design decisions (confirmed with maintainers )

  1. Short names mirroring the aggregation package, no Scalarizer suffix.
  2. Scalarizer inherits from nn.Module for consistency with Aggregator and Weighting, and to leave room for trainable scalarizers later.
  3. No Combiner generalization for now.
  4. Arbitrary-shape input, 0-dim output.
  5. Constant uses option (a): weights.shape must equal losses.shape.

The Stateful mixin stays in torchjd.aggregation._mixins for now. When the first stateful scalarizer lands, it can be moved to torchjd._mixins so both packages share it. A comment in scalarization/__init__.py records this.

Test plan

  • uv run pytest tests/unit/scalarization -W error -v passes.
  • uv run pytest tests/unit/scalarization --cov=src/torchjd/scalarization shows full coverage.
  • uv run pytest tests/unit -W error passes (3019 passed, 66 skipped, 33 xfailed).
  • PYTEST_TORCH_DTYPE=float64 uv run pytest tests/unit/scalarization -W error passes.
  • uv run ruff format --check and uv run ruff check pass.
  • uv run ty check passes.
  • uv run pre-commit run --all-files passes.

ppraneth added 3 commits May 26, 2026 19:03
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from a team as a code owner May 27, 2026 02:30
@PierreQuinton PierreQuinton added cc: feat Conventional commit type for new features. package: scalarization labels May 27, 2026
@github-actions github-actions Bot changed the title feat: Add scalarization package feat(scalarization): Add scalarization package May 27, 2026
@github-actions github-actions Bot changed the title feat: Add scalarization package feat(scalarization): Add scalarization package May 27, 2026
Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice PR, thanks!

There are few things we need to discuss with @ValerianRey but it's looking very good.

We are missing documentation (rst files in docs), also we should put a docstring in init.py that explains the package, sinilarly to aggregation.

Also @ValerianRey should we define a code owner for the package or leave it as the default (maintainers) and then after maybe a few PR from @ppraneth we can make him code owner?

Comment thread src/torchjd/scalarization/_scalarizer_base.py Outdated
Comment on lines +7 to +10
"""
Abstract base class for all scalarizers. Reduces a tensor of losses of any shape into a single
scalar loss that can be passed to :meth:`~torch.Tensor.backward`.
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ValerianRey I think I would abstract away from losses and differentiation to tensors/inputs and making them into scalars. What do you think?

Comment thread src/torchjd/scalarization/_random.py Outdated
Comment thread src/torchjd/scalarization/_random.py Outdated
Comment on lines +4 to +10
scalar_input: Tensor = tensor_(7.0)
vector_input: Tensor = randn_(5)
matrix_input: Tensor = randn_(3, 4)
tensor_3d_input: Tensor = randn_(2, 3, 4)

typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scalar_input: Tensor = tensor_(7.0)
vector_input: Tensor = randn_(5)
matrix_input: Tensor = randn_(3, 4)
tensor_3d_input: Tensor = randn_(2, 3, 4)
typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
scalar_inputs: Tensor = [randn_([]) for _ in range(3)]
vector_inputs: Tensor = [randn_([5]) for _ in range(3)]
matrix_inputs: Tensor = [randn_([3, 4]) for _ in range(3)]
tensor_3d_inputs: Tensor = [randn_([2, 3, 4]) for _ in range(3)]
typical_inputs: list[Tensor] = vector_inputs + matrix_inputs + tensor_3d_inputs
all_inputs: list[Tensor] = scalar_input + typical_inputs

Maybe we should even have many shapes instead, what do you think @ValerianRey ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rename typical_inputs to non_scalar_inputs.

Comment thread tests/unit/scalarization/test_constant.py Outdated
Comment thread tests/unit/scalarization/test_random.py Outdated
Comment thread tests/unit/scalarization/test_scalarizer_base.py Outdated
Comment thread src/torchjd/scalarization/__init__.py Outdated
ppraneth and others added 4 commits May 27, 2026 09:57
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: scalarization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add scalarization package

2 participants