From 54e2b6e69900ff6e65a96ee44e92f5a965a7d0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 16:44:33 +0200 Subject: [PATCH 1/2] Reorder code to have weightings defined before aggregators - This is needed in order to use the weighting as a type hint for a field of the aggregator --- src/torchjd/aggregation/_aligned_mtl.py | 72 ++++++------- src/torchjd/aggregation/_cagrad.py | 64 ++++++------ src/torchjd/aggregation/_constant.py | 40 ++++---- src/torchjd/aggregation/_dualproj.py | 72 ++++++------- src/torchjd/aggregation/_gradvac.py | 128 ++++++++++++------------ src/torchjd/aggregation/_imtl_g.py | 30 +++--- src/torchjd/aggregation/_krum.py | 52 +++++----- src/torchjd/aggregation/_mean.py | 20 ++-- src/torchjd/aggregation/_mgda.py | 44 ++++---- src/torchjd/aggregation/_nash_mtl.py | 128 ++++++++++++------------ src/torchjd/aggregation/_pcgrad.py | 26 ++--- src/torchjd/aggregation/_random.py | 24 ++--- src/torchjd/aggregation/_sum.py | 20 ++-- src/torchjd/aggregation/_upgrad.py | 72 ++++++------- 14 files changed, 396 insertions(+), 396 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 2708b8006..4227671d9 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -17,42 +17,6 @@ SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] -class AlignedMTL(GramianWeightedAggregator): - r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of - `Independent Component Alignment for Multi-Task Learning - `_. - - :param pref_vector: The preference vector to use. If not provided, defaults to - :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses - the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` - uses the mean eigenvalue (as in the original implementation). - - .. note:: - This implementation was adapted from the official implementation of SamsungLabs/MTL, - which is not available anymore at the time of writing. - """ - - def __init__( - self, - pref_vector: Tensor | None = None, - scale_mode: SUPPORTED_SCALE_MODE = "min", - ) -> None: - self._pref_vector = pref_vector - self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode - super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " - f"scale_mode={repr(self._scale_mode)})" - ) - - def __str__(self) -> str: - return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" - - class AlignedMTLWeighting(Weighting[PSDMatrix]): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -113,3 +77,39 @@ def _compute_balance_transformation( B = scale.sqrt() * V @ sigma_inv @ V.T return B + + +class AlignedMTL(GramianWeightedAggregator): + r""" + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of + `Independent Component Alignment for Multi-Task Learning + `_. + + :param pref_vector: The preference vector to use. If not provided, defaults to + :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param scale_mode: The scaling mode used to build the balance transformation. 
``"min"`` uses + the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` + uses the mean eigenvalue (as in the original implementation). + + .. note:: + This implementation was adapted from the official implementation of SamsungLabs/MTL, + which is not available anymore at the time of writing. + """ + + def __init__( + self, + pref_vector: Tensor | None = None, + scale_mode: SUPPORTED_SCALE_MODE = "min", + ) -> None: + self._pref_vector = pref_vector + self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode + super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " + f"scale_mode={repr(self._scale_mode)})" + ) + + def __str__(self) -> str: + return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 6731178bb..c6b42d960 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -18,38 +18,6 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class CAGrad(GramianWeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of - `Conflict-Averse Gradient Descent for Multi-task Learning - `_. - - :param c: The scale of the radius of the ball constraint. - :param norm_eps: A small value to avoid division by zero when normalizing. - - .. note:: - This aggregator is not installed by default. When not installed, trying to import it should - result in the following error: - ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``. - To install it, use ``pip install "torchjd[cagrad]"``. - """ - - def __init__(self, c: float, norm_eps: float = 0.0001) -> None: - super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) - self._c = c - self._norm_eps = norm_eps - - # This prevents considering the computed weights as constant w.r.t. the matrix. - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" - - def __str__(self) -> str: - c_str = str(self._c).rstrip("0") - return f"CAGrad{c_str}" - - class CAGradWeighting(Weighting[PSDMatrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -104,3 +72,35 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) return weights + + +class CAGrad(GramianWeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of + `Conflict-Averse Gradient Descent for Multi-task Learning + `_. + + :param c: The scale of the radius of the ball constraint. + :param norm_eps: A small value to avoid division by zero when normalizing. + + .. note:: + This aggregator is not installed by default. When not installed, trying to import it should + result in the following error: + ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``. + To install it, use ``pip install "torchjd[cagrad]"``. + """ + + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: + super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) + self._c = c + self._norm_eps = norm_eps + + # This prevents considering the computed weights as constant w.r.t. the matrix. 
+ self.register_full_backward_pre_hook(raise_non_differentiable_error) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" + + def __str__(self) -> str: + c_str = str(self._c).rstrip("0") + return f"CAGrad{c_str}" diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index a547b813b..024c94ff1 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -7,26 +7,6 @@ from ._weighting_bases import Weighting -class Constant(WeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of - the rows of the provided matrix, with constant, pre-determined weights. - - :param weights: The weights associated to the rows of the input matrices. - """ - - def __init__(self, weights: Tensor) -> None: - super().__init__(weighting=ConstantWeighting(weights=weights)) - self._weights = weights - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(weights={repr(self._weights)})" - - def __str__(self) -> str: - weights_str = vector_to_str(self._weights) - return f"{self.__class__.__name__}([{weights_str}])" - - class ConstantWeighting(Weighting[Matrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined @@ -55,3 +35,23 @@ def _check_matrix_shape(self, matrix: Tensor) -> None: f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified " f"weights). Found `matrix` with {matrix.shape[0]} rows.", ) + + +class Constant(WeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of + the rows of the provided matrix, with constant, pre-determined weights. + + :param weights: The weights associated to the rows of the input matrices. + """ + + def __init__(self, weights: Tensor) -> None: + super().__init__(weighting=ConstantWeighting(weights=weights)) + self._weights = weights + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(weights={repr(self._weights)})" + + def __str__(self) -> str: + weights_str = vector_to_str(self._weights) + return f"{self.__class__.__name__}([{weights_str}])" diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 7e868f620..61503a5bc 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -10,6 +10,42 @@ from ._weighting_bases import Weighting +class DualProjWeighting(Weighting[PSDMatrix]): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.DualProj`. + + :param pref_vector: The preference vector to use. If not provided, defaults to + :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param norm_eps: A small value to avoid division by zero when normalizing. + :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to + numerical errors when computing the gramian, it might not exactly be positive definite. + This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian + ensures that it is positive definite. + :param solver: The solver used to optimize the underlying optimization problem. 
+ """ + + def __init__( + self, + pref_vector: Tensor | None = None, + norm_eps: float = 0.0001, + reg_eps: float = 0.0001, + solver: SUPPORTED_SOLVER = "quadprog", + ) -> None: + super().__init__() + self._pref_vector = pref_vector + self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) + self.norm_eps = norm_eps + self.reg_eps = reg_eps + self.solver: SUPPORTED_SOLVER = solver + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + u = self.weighting(gramian) + G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) + w = project_weights(u, G, self.solver) + return w + + class DualProj(GramianWeightedAggregator): r""" :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input @@ -54,39 +90,3 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}" - - -class DualProjWeighting(Weighting[PSDMatrix]): - r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of - :class:`~torchjd.aggregation.DualProj`. - - :param pref_vector: The preference vector to use. If not provided, defaults to - :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. - """ - - def __init__( - self, - pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", - ) -> None: - super().__init__() - self._pref_vector = pref_vector - self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) - self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver - - def forward(self, gramian: PSDMatrix, /) -> Tensor: - u = self.weighting(gramian) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) - return w diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index e593a8eb5..758070b51 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -13,70 +13,6 @@ from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator, Stateful): - r""" - :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of - Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task - Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) - `_. - - For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at - random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the - (possibly already modified) gradient of task :math:`i` and the original gradient of task - :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When - :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of - :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. 
The EMA is then updated with - :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated - vector is the sum of the modified rows. - - This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when - the number of tasks or dtype changes. - - :param beta: EMA decay for :math:`\hat{\phi}`. - :param eps: Small non-negative constant added to denominators. - - .. note:: - For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently - using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if - you need reproducibility. - - .. note:: - To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping - strategy, please refer to the :doc:`Grouping ` examples. - """ - - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - weighting = GradVacWeighting(beta=beta, eps=eps) - super().__init__(weighting) - self._gradvac_weighting = weighting - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - @property - def beta(self) -> float: - return self._gradvac_weighting.beta - - @beta.setter - def beta(self, value: float) -> None: - self._gradvac_weighting.beta = value - - @property - def eps(self) -> float: - return self._gradvac_weighting.eps - - @eps.setter - def eps(self, value: float) -> None: - self._gradvac_weighting.eps = value - - def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" - - self._gradvac_weighting.reset() - - def __repr__(self) -> str: - return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" - - class GradVacWeighting(Weighting[PSDMatrix], Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` @@ -195,3 +131,67 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: if self._state_key != key or self._phi_t is None: self._phi_t = torch.zeros(m, m, dtype=dtype) self._state_key = key + + +class GradVac(GramianWeightedAggregator, Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of + Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task + Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) + `_. + + For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at + random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the + (possibly already modified) gradient of task :math:`i` and the original gradient of task + :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When + :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of + :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with + :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated + vector is the sum of the modified rows. + + This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when + the number of tasks or dtype changes. + + :param beta: EMA decay for :math:`\hat{\phi}`. + :param eps: Small non-negative constant added to denominators. + + .. note:: + For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently + using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if + you need reproducibility. + + .. 
note:: + To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping + strategy, please refer to the :doc:`Grouping ` examples. + """ + + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps) + super().__init__(weighting) + self._gradvac_weighting = weighting + self.register_full_backward_pre_hook(raise_non_differentiable_error) + + @property + def beta(self) -> float: + return self._gradvac_weighting.beta + + @beta.setter + def beta(self, value: float) -> None: + self._gradvac_weighting.beta = value + + @property + def eps(self) -> float: + return self._gradvac_weighting.eps + + @eps.setter + def eps(self, value: float) -> None: + self._gradvac_weighting.eps = value + + def reset(self) -> None: + """Clears EMA state so the next forward starts from zero targets.""" + + self._gradvac_weighting.reset() + + def __repr__(self) -> str: + return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 75d00b76e..2c7828499 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -8,21 +8,6 @@ from ._weighting_bases import Weighting -class IMTLG(GramianWeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in - `Towards Impartial Multi-task Learning `_. - This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization - `_, supports matrices with some linearly dependant rows. - """ - - def __init__(self) -> None: - super().__init__(IMTLGWeighting()) - - # This prevents computing gradients that can be very wrong. - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - class IMTLGWeighting(Weighting[PSDMatrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -37,3 +22,18 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum return weights + + +class IMTLG(GramianWeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in + `Towards Impartial Multi-task Learning `_. + This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization + `_, supports matrices with some linearly dependant rows. + """ + + def __init__(self) -> None: + super().__init__(IMTLGWeighting()) + + # This prevents computing gradients that can be very wrong. + self.register_full_backward_pre_hook(raise_non_differentiable_error) diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 40285d89c..5d017bf68 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -8,32 +8,6 @@ from ._weighting_bases import Weighting -class Krum(GramianWeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, - as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent - `_. - - :param n_byzantine: The number of rows of the input matrix that can come from an adversarial - source. - :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. 
- """ - - def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: - self._n_byzantine = n_byzantine - self._n_selected = n_selected - super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected=" - f"{self._n_selected})" - ) - - def __str__(self) -> str: - return f"Krum{self._n_byzantine}-{self._n_selected}" - - class KrumWeighting(Weighting[PSDMatrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -93,3 +67,29 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None: f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). " f"Found `gramian` with {gramian.shape[0]} rows.", ) + + +class Krum(GramianWeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, + as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent + `_. + + :param n_byzantine: The number of rows of the input matrix that can come from an adversarial + source. + :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. + """ + + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: + self._n_byzantine = n_byzantine + self._n_selected = n_selected + super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected=" + f"{self._n_selected})" + ) + + def __str__(self) -> str: + return f"Krum{self._n_byzantine}-{self._n_selected}" diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 8fc5b057a..2194541eb 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -7,16 +7,6 @@ from ._weighting_bases import Weighting -class Mean(WeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input - matrices. - """ - - def __init__(self) -> None: - super().__init__(weighting=MeanWeighting()) - - class MeanWeighting(Weighting[Matrix]): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights @@ -30,3 +20,13 @@ def forward(self, matrix: Tensor, /) -> Tensor: m = matrix.shape[0] weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype) return weights + + +class Mean(WeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input + matrices. + """ + + def __init__(self) -> None: + super().__init__(weighting=MeanWeighting()) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index aec329470..6cb7d1de6 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -7,28 +7,6 @@ from ._weighting_bases import Weighting -class MGDA(GramianWeightedAggregator): - r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation - step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization - `_. - The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective - Optimization - `_. - - :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. - :param max_iters: The maximum number of iterations of the optimization loop. 
- """ - - def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: - super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) - self._epsilon = epsilon - self._max_iters = max_iters - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})" - - class MGDAWeighting(Weighting[PSDMatrix]): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -70,3 +48,25 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: if gamma < self.epsilon: break return alpha + + +class MGDA(GramianWeightedAggregator): + r""" + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation + step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization + `_. + The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective + Optimization + `_. + + :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. + :param max_iters: The maximum number of iterations of the optimization loop. + """ + + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: + super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) + self._epsilon = epsilon + self._max_iters = max_iters + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})" diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index f64f51823..b3c4672b2 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -21,70 +21,6 @@ from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(WeightedAggregator, Stateful): - """ - :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of - `Multi-Task Learning as a Bargaining Game `_. - - :param n_tasks: The number of tasks, corresponding to the number of rows in the provided - matrices. - :param max_norm: Maximum value of the norm of :math:`J^T w`. - :param update_weights_every: A parameter determining how often the actual weighting should be - performed. A larger value means that the same weights will be re-used for more calls to the - aggregator. - :param optim_niter: The number of iterations of the underlying optimization process. - - .. note:: - This aggregator is not installed by default. When not installed, trying to import it should - result in the following error: - ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``. - To install it, use ``pip install "torchjd[nash_mtl]"``. - - .. warning:: - This implementation was adapted from the `official implementation - `_, which has some flaws. Use with caution. - - .. warning:: - This aggregator is stateful. Its output will thus depend not only on the input matrix, but - also on its state. It thus depends on previously seen matrices. It should be reset between - experiments. 
- """ - - def __init__( - self, - n_tasks: int, - max_norm: float = 1.0, - update_weights_every: int = 1, - optim_niter: int = 20, - ) -> None: - super().__init__( - weighting=_NashMTLWeighting( - n_tasks=n_tasks, - max_norm=max_norm, - update_weights_every=update_weights_every, - optim_niter=optim_niter, - ), - ) - self._n_tasks = n_tasks - self._max_norm = max_norm - self._update_weights_every = update_weights_every - self._optim_niter = optim_niter - - # This prevents considering the computed weights as constant w.r.t. the matrix. - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - def reset(self) -> None: - """Resets the internal state of the algorithm.""" - cast(_NashMTLWeighting, self.weighting).reset() - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(n_tasks={self._n_tasks}, max_norm={self._max_norm}, " - f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter})" - ) - - class _NashMTLWeighting(Weighting[Matrix], Stateful): """ :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that @@ -211,3 +147,67 @@ def reset(self) -> None: self.init_gtg = np.eye(self.n_tasks) self.step = 0.0 self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) + + +class NashMTL(WeightedAggregator, Stateful): + """ + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of + `Multi-Task Learning as a Bargaining Game `_. + + :param n_tasks: The number of tasks, corresponding to the number of rows in the provided + matrices. + :param max_norm: Maximum value of the norm of :math:`J^T w`. + :param update_weights_every: A parameter determining how often the actual weighting should be + performed. A larger value means that the same weights will be re-used for more calls to the + aggregator. + :param optim_niter: The number of iterations of the underlying optimization process. + + .. note:: + This aggregator is not installed by default. When not installed, trying to import it should + result in the following error: + ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``. + To install it, use ``pip install "torchjd[nash_mtl]"``. + + .. warning:: + This implementation was adapted from the `official implementation + `_, which has some flaws. Use with caution. + + .. warning:: + This aggregator is stateful. Its output will thus depend not only on the input matrix, but + also on its state. It thus depends on previously seen matrices. It should be reset between + experiments. + """ + + def __init__( + self, + n_tasks: int, + max_norm: float = 1.0, + update_weights_every: int = 1, + optim_niter: int = 20, + ) -> None: + super().__init__( + weighting=_NashMTLWeighting( + n_tasks=n_tasks, + max_norm=max_norm, + update_weights_every=update_weights_every, + optim_niter=optim_niter, + ), + ) + self._n_tasks = n_tasks + self._max_norm = max_norm + self._update_weights_every = update_weights_every + self._optim_niter = optim_niter + + # This prevents considering the computed weights as constant w.r.t. the matrix. 
+ self.register_full_backward_pre_hook(raise_non_differentiable_error) + + def reset(self) -> None: + """Resets the internal state of the algorithm.""" + cast(_NashMTLWeighting, self.weighting).reset() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(n_tasks={self._n_tasks}, max_norm={self._max_norm}, " + f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter})" + ) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 0f1241df7..42efeb472 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -10,19 +10,6 @@ from ._weighting_bases import Weighting -class PCGrad(GramianWeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of - `Gradient Surgery for Multi-Task Learning `_. - """ - - def __init__(self) -> None: - super().__init__(PCGradWeighting()) - - # This prevents running into a RuntimeError due to modifying stored tensors in place. - self.register_full_backward_pre_hook(raise_non_differentiable_error) - - class PCGradWeighting(Weighting[PSDMatrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of @@ -57,3 +44,16 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = weights + current_weights return weights.to(device) + + +class PCGrad(GramianWeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of + `Gradient Surgery for Multi-Task Learning `_. + """ + + def __init__(self) -> None: + super().__init__(PCGradWeighting()) + + # This prevents running into a RuntimeError due to modifying stored tensors in place. + self.register_full_backward_pre_hook(raise_non_differentiable_error) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 734dfc177..00f3859fb 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -8,18 +8,6 @@ from ._weighting_bases import Weighting -class Random(WeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of - the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of - Random Weighting: A Litmus Test for Multi-Task Learning - `_. - """ - - def __init__(self) -> None: - super().__init__(RandomWeighting()) - - class RandomWeighting(Weighting[Matrix]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights @@ -30,3 +18,15 @@ def forward(self, matrix: Tensor, /) -> Tensor: random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) weights = F.softmax(random_vector, dim=-1) return weights + + +class Random(WeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of + the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of + Random Weighting: A Litmus Test for Multi-Task Learning + `_. 
+ """ + + def __init__(self) -> None: + super().__init__(RandomWeighting()) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index aaf73f029..f68972b76 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -7,16 +7,6 @@ from ._weighting_bases import Weighting -class Sum(WeightedAggregator): - """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input - matrices. - """ - - def __init__(self) -> None: - super().__init__(weighting=SumWeighting()) - - class SumWeighting(Weighting[Matrix]): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights @@ -28,3 +18,13 @@ def forward(self, matrix: Tensor, /) -> Tensor: dtype = matrix.dtype weights = torch.ones(matrix.shape[0], device=device, dtype=dtype) return weights + + +class Sum(WeightedAggregator): + """ + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input + matrices. + """ + + def __init__(self) -> None: + super().__init__(weighting=SumWeighting()) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 45f760be9..41d9db55c 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -11,6 +11,42 @@ from ._weighting_bases import Weighting +class UPGradWeighting(Weighting[PSDMatrix]): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.UPGrad`. + + :param pref_vector: The preference vector to use. If not provided, defaults to + :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param norm_eps: A small value to avoid division by zero when normalizing. + :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to + numerical errors when computing the gramian, it might not exactly be positive definite. + This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian + ensures that it is positive definite. + :param solver: The solver used to optimize the underlying optimization problem. + """ + + def __init__( + self, + pref_vector: Tensor | None = None, + norm_eps: float = 0.0001, + reg_eps: float = 0.0001, + solver: SUPPORTED_SOLVER = "quadprog", + ) -> None: + super().__init__() + self._pref_vector = pref_vector + self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) + self.norm_eps = norm_eps + self.reg_eps = reg_eps + self.solver: SUPPORTED_SOLVER = solver + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + U = torch.diag(self.weighting(gramian)) + G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) + W = project_weights(U, G, self.solver) + return torch.sum(W, dim=0) + + class UPGrad(GramianWeightedAggregator): r""" :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that projects each row of the input @@ -55,39 +91,3 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}" - - -class UPGradWeighting(Weighting[PSDMatrix]): - r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of - :class:`~torchjd.aggregation.UPGrad`. - - :param pref_vector: The preference vector to use. If not provided, defaults to - :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. 
- :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. - """ - - def __init__( - self, - pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", - ) -> None: - super().__init__() - self._pref_vector = pref_vector - self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) - self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver - - def forward(self, gramian: PSDMatrix, /) -> Tensor: - U = torch.diag(self.weighting(gramian)) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) - return torch.sum(W, dim=0) From eddd329c2dba6ef0a1da56d4f827e5ca6adcb511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 16:50:33 +0200 Subject: [PATCH 2/2] Add type hints to weighting and gramian_weighting fields --- src/torchjd/aggregation/_aligned_mtl.py | 2 ++ src/torchjd/aggregation/_cagrad.py | 2 ++ src/torchjd/aggregation/_constant.py | 2 ++ src/torchjd/aggregation/_dualproj.py | 2 ++ src/torchjd/aggregation/_gradvac.py | 2 ++ src/torchjd/aggregation/_imtl_g.py | 2 ++ src/torchjd/aggregation/_krum.py | 2 ++ src/torchjd/aggregation/_mean.py | 2 ++ src/torchjd/aggregation/_mgda.py | 2 ++ src/torchjd/aggregation/_nash_mtl.py | 2 ++ src/torchjd/aggregation/_pcgrad.py | 2 ++ src/torchjd/aggregation/_random.py | 2 ++ src/torchjd/aggregation/_sum.py | 2 ++ src/torchjd/aggregation/_upgrad.py | 2 ++ 14 files changed, 28 insertions(+) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 4227671d9..ced7ae459 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -96,6 +96,8 @@ class AlignedMTL(GramianWeightedAggregator): which is not available anymore at the time of writing. """ + gramian_weighting: AlignedMTLWeighting + def __init__( self, pref_vector: Tensor | None = None, diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index c6b42d960..b008fefb3 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -90,6 +90,8 @@ class CAGrad(GramianWeightedAggregator): To install it, use ``pip install "torchjd[cagrad]"``. """ + gramian_weighting: CAGradWeighting + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) self._c = c diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 024c94ff1..0485e7261 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -45,6 +45,8 @@ class Constant(WeightedAggregator): :param weights: The weights associated to the rows of the input matrices. 
""" + weighting: ConstantWeighting + def __init__(self, weights: Tensor) -> None: super().__init__(weighting=ConstantWeighting(weights=weights)) self._weights = weights diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 61503a5bc..372cd18d1 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -63,6 +63,8 @@ class DualProj(GramianWeightedAggregator): :param solver: The solver used to optimize the underlying optimization problem. """ + gramian_weighting: DualProjWeighting + def __init__( self, pref_vector: Tensor | None = None, diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 758070b51..cc518fbbc 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -166,6 +166,8 @@ class GradVac(GramianWeightedAggregator, Stateful): strategy, please refer to the :doc:`Grouping ` examples. """ + gramian_weighting: GradVacWeighting + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: weighting = GradVacWeighting(beta=beta, eps=eps) super().__init__(weighting) diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 2c7828499..42062e93c 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -32,6 +32,8 @@ class IMTLG(GramianWeightedAggregator): `_, supports matrices with some linearly dependant rows. """ + gramian_weighting: IMTLGWeighting + def __init__(self) -> None: super().__init__(IMTLGWeighting()) diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 5d017bf68..70e072024 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -80,6 +80,8 @@ class Krum(GramianWeightedAggregator): :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. """ + gramian_weighting: KrumWeighting + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: self._n_byzantine = n_byzantine self._n_selected = n_selected diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 2194541eb..2ebe208de 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -28,5 +28,7 @@ class Mean(WeightedAggregator): matrices. """ + weighting: MeanWeighting + def __init__(self) -> None: super().__init__(weighting=MeanWeighting()) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 6cb7d1de6..510fa7256 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -63,6 +63,8 @@ class MGDA(GramianWeightedAggregator): :param max_iters: The maximum number of iterations of the optimization loop. """ + gramian_weighting: MGDAWeighting + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) self._epsilon = epsilon diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index b3c4672b2..e48b32c83 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -179,6 +179,8 @@ class NashMTL(WeightedAggregator, Stateful): experiments. 
""" + weighting: _NashMTLWeighting + def __init__( self, n_tasks: int, diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 42efeb472..770ffe09f 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -52,6 +52,8 @@ class PCGrad(GramianWeightedAggregator): `Gradient Surgery for Multi-Task Learning `_. """ + gramian_weighting: PCGradWeighting + def __init__(self) -> None: super().__init__(PCGradWeighting()) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 00f3859fb..8345a15cb 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -28,5 +28,7 @@ class Random(WeightedAggregator): `_. """ + weighting: RandomWeighting + def __init__(self) -> None: super().__init__(RandomWeighting()) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index f68972b76..0754f4668 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -26,5 +26,7 @@ class Sum(WeightedAggregator): matrices. """ + weighting: SumWeighting + def __init__(self) -> None: super().__init__(weighting=SumWeighting()) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 41d9db55c..172e55a65 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -64,6 +64,8 @@ class UPGrad(GramianWeightedAggregator): :param solver: The solver used to optimize the underlying optimization problem. """ + gramian_weighting: UPGradWeighting + def __init__( self, pref_vector: Tensor | None = None,