From 5a6358f4b61f64334113186f35cdd0996884412e Mon Sep 17 00:00:00 2001 From: Khush Date: Sat, 20 Jun 2026 11:05:41 -0400 Subject: [PATCH] feat(Aggregation): Add ExcessMTLWeighting --- CHANGELOG.md | 1 + NOTICES | 28 +++ docs/source/docs/aggregation/excess_mtl.rst | 7 + docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_excess_mtl.py | 184 +++++++++++++++++ tests/unit/aggregation/test_excess_mtl.py | 209 ++++++++++++++++++++ 7 files changed, 432 insertions(+) create mode 100644 docs/source/docs/aggregation/excess_mtl.rst create mode 100644 src/torchjd/aggregation/_excess_mtl.py create mode 100644 tests/unit/aggregation/test_excess_mtl.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 52107b54..287b3b14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ changelog does not include internal changes that do not affect the user. Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a `Scalarizer` that decomposes the values into a component along a preference direction and a penalized perpendicular component. +- Added `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin. ## [0.15.0] - 2026-06-15 diff --git a/NOTICES b/NOTICES index 098695c3..0f1ac03a 100644 --- a/NOTICES +++ b/NOTICES @@ -143,6 +143,34 @@ SOFTWARE. ------------------------------------------------------------------------------- +Project: ExcessMTL +Source: https://github.com/uiuctml/ExcessMTL/blob/main/LibMTL/LibMTL/weighting/ExcessMTL.py +Used in: src/torchjd/aggregation/_excess_mtl.py + +MIT License + +Copyright (c) 2024 UIUC TML Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +------------------------------------------------------------------------------- + Project: SDMGrad Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py Used in: src/torchjd/aggregation/_sdmgrad.py diff --git a/docs/source/docs/aggregation/excess_mtl.rst b/docs/source/docs/aggregation/excess_mtl.rst new file mode 100644 index 00000000..e79fcc88 --- /dev/null +++ b/docs/source/docs/aggregation/excess_mtl.rst @@ -0,0 +1,7 @@ +:hide-toc: + +ExcessMTL +========= + +.. autoclass:: torchjd.aggregation.ExcessMTLWeighting + :members: __call__, reset diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 867e8d8c..b77332db 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -30,6 +30,7 @@ Abstract base classes constant.rst cr_mogm.rst dualproj.rst + excess_mtl.rst fairgrad.rst graddrop.rst gradvac.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 5285e6bf..61cae873 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -45,6 +45,7 @@ from ._constant import Constant, ConstantWeighting from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting +from ._excess_mtl import ExcessMTLWeighting from ._fairgrad import FairGrad, FairGradWeighting from ._graddrop import GradDrop from ._gradvac import GradVac, GradVacWeighting @@ -74,6 +75,7 @@ "CRMOGMWeighting", "DualProj", "DualProjWeighting", + "ExcessMTLWeighting", "FairGrad", "FairGradWeighting", "GradDrop", diff --git a/src/torchjd/aggregation/_excess_mtl.py b/src/torchjd/aggregation/_excess_mtl.py new file mode 100644 index 00000000..266fe25c --- /dev/null +++ b/src/torchjd/aggregation/_excess_mtl.py @@ -0,0 +1,184 @@ +# Partly adapted from https://github.com/uiuctml/ExcessMTL — MIT License, Copyright (c) 2024 UIUC TML Lab. +# See NOTICES for the full license text. +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd._mixins import Stateful +from torchjd.aggregation._mixins import _NonDifferentiable +from torchjd.linalg import Matrix + +from ._weighting_bases import _MatrixWeighting + + +class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Robust + Multi-Task Learning with Excess Risks + `_ (ICML 2024). + + At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven + by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a + second-order Taylor expansion (Equations 6-7): + + :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update. + Must be positive. + :param n_warmup_steps: Number of forward calls during which weights stay uniform + (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess + risk is set to the average excess risk observed during warmup. When ``0`` (default), the + first call's excess risk is used as the baseline and weights are updated immediately + (matching the official implementation). + + .. warning:: + The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients + across **all** calls, where :math:`n` is the total number of model parameters. For large + models this can be a significant memory cost. Call :meth:`reset` between experiments. + + .. note:: + The weight update is adapted from the `official implementation + `_ and `LibMTL + `_. + The warmup strategy follows Appendix C.1 of the paper, which recommends collecting + gradient statistics for several epochs before beginning weight updates; set + ``n_warmup_steps`` accordingly (e.g. ``3 * len(dataloader)``). + + .. admonition:: Example + + .. testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd import autojac + from torchjd.aggregation import ExcessMTLWeighting, WeightedAggregator + from torchjd.autojac import jac_to_grad + + inputs = torch.randn(8, 5) + targets = torch.randn(8, 2) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 2)) + optimizer = SGD(model.parameters()) + criterion = MSELoss() + aggregator = WeightedAggregator(ExcessMTLWeighting()) + + outputs = model(inputs) + losses = [criterion(outputs[:, i], targets[:, i]) for i in range(2)] + autojac.backward(losses) + jac_to_grad(model.parameters(), aggregator) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__( + self, + robust_step_size: float = 1.0, + n_warmup_steps: int = 0, + ) -> None: + super().__init__() + self.robust_step_size = robust_step_size + self.n_warmup_steps = n_warmup_steps + self.register_buffer("_weights", None) + self.register_buffer("_grad_sum", None) + self.register_buffer("_initial_w", None) + self.register_buffer("_warmup_w_sum", None) + self.register_buffer("_n_steps", torch.zeros((), dtype=torch.long)) + self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None + + @property + def robust_step_size(self) -> float: + return self._robust_step_size + + @robust_step_size.setter + def robust_step_size(self, value: float) -> None: + if value <= 0.0: + raise ValueError( + f"Attribute `robust_step_size` must be positive. Found robust_step_size={value!r}." + ) + self._robust_step_size = value + + @property + def n_warmup_steps(self) -> int: + return self._n_warmup_steps + + @n_warmup_steps.setter + def n_warmup_steps(self, value: int) -> None: + if value < 0: + raise ValueError( + f"Attribute `n_warmup_steps` must be non-negative. Found n_warmup_steps={value!r}." + ) + self._n_warmup_steps = value + + def reset(self) -> None: + """Clears all state so the next forward starts from uniform weights and re-enters + warmup.""" + + self._weights = None + self._grad_sum = None + self._initial_w = None + self._warmup_w_sum = None + self._n_steps.zero_() + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + + # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7) + grad_sum = cast(Tensor, self._grad_sum) + grad_sum = grad_sum + matrix.detach() ** 2 + self._grad_sum = grad_sum + + # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6) + h = torch.sqrt(grad_sum + 1e-7) + w = (matrix.detach() ** 2 / h).sum(dim=1) # shape [m] + + n_steps = int(self._n_steps.item()) + self._n_steps = self._n_steps + 1 + + # Warmup: collect excess risk stats but return uniform weights + if n_steps < self._n_warmup_steps: + warmup_w_sum = self._warmup_w_sum + self._warmup_w_sum = w if warmup_w_sum is None else cast(Tensor, warmup_w_sum) + w + return cast(Tensor, self._weights) + + # Set baseline on the first non-warmup call + if self._initial_w is None: + if self._n_warmup_steps > 0: + # Average excess risk observed during warmup (Appendix C.1) + self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps + w = w / (cast(Tensor, self._initial_w) + 1e-7) + else: + # Official impl behaviour: first call's excess is the baseline; use w raw + self._initial_w = w + else: + w = w / (cast(Tensor, self._initial_w) + 1e-7) + + # Exponentiated gradient weight update (Equation 9) + weights = cast(Tensor, self._weights) + weights = weights * torch.exp(w * self._robust_step_size) + weights = weights / weights.sum() + self._weights = weights + return weights + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device) + if self._state_key == key and self._grad_sum is not None: + return + m, n = matrix.shape + self._grad_sum = matrix.new_zeros(m, n) + self._weights = matrix.new_full((m,), 1.0 / m) + self._initial_w = None + self._warmup_w_sum = None + self._n_steps.zero_() + self._state_key = key + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"robust_step_size={self.robust_step_size!r}, " + f"n_warmup_steps={self.n_warmup_steps!r})" + ) diff --git a/tests/unit/aggregation/test_excess_mtl.py b/tests/unit/aggregation/test_excess_mtl.py new file mode 100644 index 00000000..f0fcb9dc --- /dev/null +++ b/tests/unit/aggregation/test_excess_mtl.py @@ -0,0 +1,209 @@ +import torch +from pytest import raises +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._excess_mtl import ExcessMTLWeighting + + +def test_representations() -> None: + W = ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0) + assert repr(W) == "ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0)" + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + W = ExcessMTLWeighting() + first = W(J) + W(J) + W.reset() + assert_close(first, W(J)) + + +def test_robust_step_size_setter_accepts_valid() -> None: + W = ExcessMTLWeighting() + W.robust_step_size = 0.1 + assert W.robust_step_size == 0.1 + W.robust_step_size = 10.0 + assert W.robust_step_size == 10.0 + + +def test_robust_step_size_setter_rejects_non_positive() -> None: + W = ExcessMTLWeighting() + with raises(ValueError, match="robust_step_size"): + W.robust_step_size = 0.0 + with raises(ValueError, match="robust_step_size"): + W.robust_step_size = -1.0 + + +def test_n_warmup_steps_setter_accepts_valid() -> None: + W = ExcessMTLWeighting() + W.n_warmup_steps = 0 + assert W.n_warmup_steps == 0 + W.n_warmup_steps = 100 + assert W.n_warmup_steps == 100 + + +def test_n_warmup_steps_setter_rejects_negative() -> None: + W = ExcessMTLWeighting() + with raises(ValueError, match="n_warmup_steps"): + W.n_warmup_steps = -1 + + +def test_output_lies_on_simplex() -> None: + """The exponentiated update followed by normalisation keeps the weights on the simplex.""" + + J = randn_((4, 10)) + W = ExcessMTLWeighting() + # Call twice so the second call exercises the normalised-w branch + W(J) + weights = W(J) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_warmup_returns_uniform() -> None: + """During warmup every call must return [1/m, ..., 1/m] regardless of the input.""" + + m, n_warmup = 3, 5 + W = ExcessMTLWeighting(n_warmup_steps=n_warmup) + expected = tensor_([1.0 / m] * m) + for _ in range(n_warmup): + assert_close(W(randn_((m, 8))), expected) + + +def test_weights_change_after_warmup() -> None: + """After warmup ends the weights must diverge from uniform when tasks have different excess risks.""" + + W = ExcessMTLWeighting(n_warmup_steps=2, robust_step_size=1.0) + # Symmetric warmup: equal excess risk for both tasks → equal initial_w + J_sym = tensor_([[1.0, 0.0], [1.0, 0.0]]) + W(J_sym) + W(J_sym) + + # Asymmetric step: task 0 has larger gradient → higher excess → weight must exceed task 1 + J_unequal = tensor_([[2.0, 0.0], [1.0, 0.0]]) + weights = W(J_unequal) + assert weights[0] > weights[1] + + +def test_update_recurrence() -> None: + """Verify the first weight update manually (n_warmup_steps=0, LibMTL behaviour). + + With J = [[2., 0.], [1., 0.]] and robust_step_size=1.0: + grad_sum = [[4., 0.], [1., 0.]] + h ≈ [[2., sqrt(eps)], [1., sqrt(eps)]] (eps = 1e-7, negligible in float32 for nonzero entries) + w = [4/2 + 0, 1/1 + 0] = [2, 1] + initial_w = [2, 1] (first call: save raw excess as baseline) + weights = [exp(2), exp(1)] / (exp(2) + exp(1)) + """ + J = tensor_([[2.0, 0.0], [1.0, 0.0]]) + W = ExcessMTLWeighting(robust_step_size=1.0) + e2 = torch.exp(tensor_(2.0)) + e1 = torch.exp(tensor_(1.0)) + assert_close(W(J), tensor_([e2 / (e2 + e1), e1 / (e2 + e1)])) + + +def test_two_consecutive_steps() -> None: + """Verify warm-started carry-over across two calls. + + Call 1: J = [[2., 0.], [1., 0.]] → weights = [e^2, e] / (e^2 + e) (from test above) + Call 2: J = [[1., 0.], [2., 0.]] + grad_sum = [[4+1., 0.], [1+4., 0.]] = [[5., 0.], [5., 0.]] + h ≈ [[sqrt(5), sqrt(eps)], [sqrt(5), sqrt(eps)]] + w = [1/sqrt(5), 4/sqrt(5)] + initial_w = [2, 1] (from call 1) + w_norm = [1/(2*sqrt(5)), 4/sqrt(5)] + weights_2 = weights_1 * [exp(w_norm_0), exp(w_norm_1)] / normalization + """ + J1 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + J2 = tensor_([[1.0, 0.0], [2.0, 0.0]]) + W = ExcessMTLWeighting(robust_step_size=1.0) + + e2 = torch.exp(tensor_(2.0)) + e1 = torch.exp(tensor_(1.0)) + weights_1 = tensor_([e2 / (e2 + e1), e1 / (e2 + e1)]) + assert_close(W(J1), weights_1) + + sqrt5 = torch.sqrt(tensor_(5.0)) + w_norm_0 = tensor_(1.0) / (tensor_(2.0) * sqrt5) + w_norm_1 = tensor_(4.0) / sqrt5 + unnorm_0 = weights_1[0] * torch.exp(w_norm_0) + unnorm_1 = weights_1[1] * torch.exp(w_norm_1) + weights_2 = tensor_([unnorm_0 / (unnorm_0 + unnorm_1), unnorm_1 / (unnorm_0 + unnorm_1)]) + assert_close(W(J2), weights_2) + + +def test_warmup_baseline_is_average() -> None: + """initial_w after warmup must equal the average excess risk collected during warmup. + + With n_warmup_steps=2 and J1=[[2,0],[1,0]], J2=[[1,0],[2,0]]: + + Warmup call 1 — grad_sum_1 = J1**2 = [[4,0],[1,0]]: + h_1 ≈ [[2, sqrt(eps)], [1, sqrt(eps)]] + w_1 = [4/2, 1/1] = [2, 1] + + Warmup call 2 — grad_sum_2 = J1**2 + J2**2 = [[5,0],[5,0]]: + h_2 ≈ [[sqrt(5), sqrt(eps)], [sqrt(5), sqrt(eps)]] + w_2 = [1/sqrt(5), 4/sqrt(5)] + + initial_w = (w_1 + w_2) / 2 (Appendix C.1 average) + + Post-warmup call 3 with J3 = J1 — grad_sum_3 = [[9,0],[6,0]]: + h_3 ≈ [[3, sqrt(eps)], [sqrt(6), sqrt(eps)]] + w_3 = [4/3, 1/sqrt(6)] + w_norm = w_3 / (initial_w + 1e-7) + weights = [0.5, 0.5] * exp(w_norm) / normalize + """ + + J1 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + J2 = tensor_([[1.0, 0.0], [2.0, 0.0]]) + J3 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + W = ExcessMTLWeighting(n_warmup_steps=2, robust_step_size=1.0) + + W(J1) # warmup step 1 — grad_sum becomes J1**2 + W(J2) # warmup step 2 — grad_sum becomes J1**2 + J2**2 + + grad_sum_1 = J1**2 + h_1 = torch.sqrt(grad_sum_1 + 1e-7) + w_1 = (J1**2 / h_1).sum(dim=1) + + grad_sum_2 = grad_sum_1 + J2**2 + h_2 = torch.sqrt(grad_sum_2 + 1e-7) + w_2 = (J2**2 / h_2).sum(dim=1) + + initial_w = (w_1 + w_2) / 2 # Appendix C.1 baseline + + grad_sum_3 = grad_sum_2 + J3**2 + h_3 = torch.sqrt(grad_sum_3 + 1e-7) + w_3 = (J3**2 / h_3).sum(dim=1) + w_norm = w_3 / (initial_w + 1e-7) + pre_norm = tensor_([0.5, 0.5]) * torch.exp(w_norm) + expected = pre_norm / pre_norm.sum() + + assert_close(W(J3), expected) + + +def test_n_steps_resets_on_m_change() -> None: + """When the number of objectives changes the warmup counter must restart.""" + + W = ExcessMTLWeighting(n_warmup_steps=10) + # Burn through 5 warmup steps + for _ in range(5): + W(randn_((3, 8))) + + # Switch to 2 objectives — state including step counter resets + fresh = ExcessMTLWeighting(n_warmup_steps=10) + J = randn_((2, 8)) + assert_close(W(J), fresh(J)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + J = randn_((3, 8)) + J.requires_grad_(True) + W = ExcessMTLWeighting() + weights = W(J) + assert not weights.requires_grad