diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46aa6333..fe52e1fe 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,9 +10,15 @@ changelog does not include internal changes that do not affect the user.
 
 ### Added
 
-- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
-  `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
-  non-negative.
+- Added getters and setters for the constructor parameters of all aggregators and weightings, so
+  that they can be changed after initialization. This includes: `pref_vector`,
+  `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
+  `pref_vector` and `scale_mode` in `AlignedMTL` and `AlignedMTLWeighting`; `c` and `norm_eps` in
+  `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`; `n_byzantine`
+  and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and
+  `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
+  `trim_number` in `TrimmedMean`. Setters validate their inputs with the same checks as the
+  existing constructors. Note that setters for `GradVac` and `GradVacWeighting` already existed.
 
 ## [0.10.0] - 2026-04-16
 
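# Usage sketch for the changelog entry above (editor's illustration, not part of the
# patch): constructor parameters can now be tuned on an existing aggregator instead of
# rebuilding it. Assumes CAGrad's optional dependencies are installed.
from torchjd.aggregation import CAGrad

aggregator = CAGrad(c=0.5)
aggregator.c = 1.25        # validated exactly like the constructor argument
aggregator.norm_eps = 1e-3
print(repr(aggregator))    # CAGrad(c=1.25, norm_eps=0.001)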
f"AlignedMTL{pref_vector_to_str_suffix(self.pref_vector)}" diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 88ca66e0..40ce8515 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -37,10 +37,6 @@ class CAGradWeighting(GramianWeighting): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__() - - if c < 0.0: - raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.") - self.c = c self.norm_eps = norm_eps @@ -73,6 +69,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights + @property + def c(self) -> float: + return self._c + + @c.setter + def c(self, value: float) -> None: + if value < 0: + raise ValueError(f"c must be non-negative, but got {value}.") + + self._c = value + + @property + def norm_eps(self) -> float: + return self._norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + if value < 0: + raise ValueError(f"norm_eps must be non-negative, but got {value}.") + + self._norm_eps = value + class CAGrad(GramianWeightedAggregator): """ @@ -94,15 +112,29 @@ class CAGrad(GramianWeightedAggregator): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) - self._c = c - self._norm_eps = norm_eps # This prevents considering the computed weights as constant w.r.t. the matrix. self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def c(self) -> float: + return self.gramian_weighting.c + + @c.setter + def c(self, value: float) -> None: + self.gramian_weighting.c = value + + @property + def norm_eps(self) -> float: + return self.gramian_weighting.norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + self.gramian_weighting.norm_eps = value + def __repr__(self) -> str: - return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" + return f"{self.__class__.__name__}(c={self.c}, norm_eps={self.norm_eps})" def __str__(self) -> str: - c_str = str(self._c).rstrip("0") + c_str = str(self.c).rstrip("0") return f"CAGrad{c_str}" diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 261f3a64..f19c4023 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -29,8 +29,7 @@ class ConFIG(Aggregator): def __init__(self, pref_vector: Tensor | None = None) -> None: super().__init__() - self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting()) - self._pref_vector = pref_vector + self.pref_vector = pref_vector # This prevents computing gradients that can be very wrong. 
diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py
index 88ca66e0..40ce8515 100644
--- a/src/torchjd/aggregation/_cagrad.py
+++ b/src/torchjd/aggregation/_cagrad.py
@@ -37,10 +37,6 @@ class CAGradWeighting(GramianWeighting):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__()
-
-        if c < 0.0:
-            raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")
-
         self.c = c
         self.norm_eps = norm_eps
 
@@ -73,6 +69,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
 
         return weights
 
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"c must be non-negative, but got {value}.")
+
+        self._c = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self._norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"norm_eps must be non-negative, but got {value}.")
+
+        self._norm_eps = value
+
 
 class CAGrad(GramianWeightedAggregator):
     """
@@ -94,15 +112,29 @@ class CAGrad(GramianWeightedAggregator):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
-        self._c = c
-        self._norm_eps = norm_eps
 
         # This prevents considering the computed weights as constant w.r.t. the matrix.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
+    @property
+    def c(self) -> float:
+        return self.gramian_weighting.c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        self.gramian_weighting.c = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self.gramian_weighting.norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        self.gramian_weighting.norm_eps = value
+
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
+        return f"{self.__class__.__name__}(c={self.c}, norm_eps={self.norm_eps})"
 
     def __str__(self) -> str:
-        c_str = str(self._c).rstrip("0")
+        c_str = str(self.c).rstrip("0")
         return f"CAGrad{c_str}"
diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py
index 261f3a64..f19c4023 100644
--- a/src/torchjd/aggregation/_config.py
+++ b/src/torchjd/aggregation/_config.py
@@ -29,8 +29,7 @@ class ConFIG(Aggregator):
 
     def __init__(self, pref_vector: Tensor | None = None) -> None:
         super().__init__()
-        self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
-        self._pref_vector = pref_vector
+        self.pref_vector = pref_vector
 
         # This prevents computing gradients that can be very wrong.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
@@ -46,8 +45,17 @@ def forward(self, matrix: Matrix, /) -> Tensor:
 
         return length * unit_target_vector
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self._pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.weighting = pref_vector_to_weighting(value, default=SumWeighting())
+        self._pref_vector = value
+
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
+        return f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)})"
 
     def __str__(self) -> str:
-        return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
+        return f"ConFIG{pref_vector_to_str_suffix(self.pref_vector)}"
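# Sketch of the validation behavior the changelog describes (editor's illustration, not
# part of the patch): a setter raises the same ValueError a bad constructor argument
# used to raise.
from torchjd.aggregation._cagrad import CAGradWeighting

W = CAGradWeighting(c=0.5)
try:
    W.norm_eps = -1e-9
except ValueError as error:
    print(error)  # norm_eps must be non-negative, but got -1e-09.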
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index 087e8805..acb87d2f 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -33,8 +33,7 @@ def __init__(
         solver: SUPPORTED_SOLVER = "quadprog",
     ) -> None:
         super().__init__()
-        self._pref_vector = pref_vector
-        self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
+        self.pref_vector = pref_vector
         self.norm_eps = norm_eps
         self.reg_eps = reg_eps
         self.solver: SUPPORTED_SOLVER = solver
@@ -45,6 +44,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         w = project_weights(u, G, self.solver)
         return w
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self._pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
+        self._pref_vector = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self._norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"norm_eps must be non-negative, but got {value}.")
+
+        self._norm_eps = value
+
+    @property
+    def reg_eps(self) -> float:
+        return self._reg_eps
+
+    @reg_eps.setter
+    def reg_eps(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"reg_eps must be non-negative, but got {value}.")
+
+        self._reg_eps = value
+
 
 class DualProj(GramianWeightedAggregator):
     r"""
@@ -72,9 +102,6 @@ def __init__(
         reg_eps: float = 0.0001,
         solver: SUPPORTED_SOLVER = "quadprog",
     ) -> None:
-        self._pref_vector = pref_vector
-        self._norm_eps = norm_eps
-        self._reg_eps = reg_eps
         self._solver: SUPPORTED_SOLVER = solver
 
         super().__init__(
@@ -84,11 +111,35 @@
         # This prevents considering the computed weights as constant w.r.t. the matrix.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
+    @property
+    def pref_vector(self) -> Tensor | None:
+        return self.gramian_weighting.pref_vector
+
+    @pref_vector.setter
+    def pref_vector(self, value: Tensor | None) -> None:
+        self.gramian_weighting.pref_vector = value
+
+    @property
+    def norm_eps(self) -> float:
+        return self.gramian_weighting.norm_eps
+
+    @norm_eps.setter
+    def norm_eps(self, value: float) -> None:
+        self.gramian_weighting.norm_eps = value
+
+    @property
+    def reg_eps(self) -> float:
+        return self.gramian_weighting.reg_eps
+
+    @reg_eps.setter
+    def reg_eps(self, value: float) -> None:
+        self.gramian_weighting.reg_eps = value
+
     def __repr__(self) -> str:
         return (
-            f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
-            f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
+            f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
+            f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
         )
 
     def __str__(self) -> str:
-        return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
+        return f"DualProj{pref_vector_to_str_suffix(self.pref_vector)}"
diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py
index 61c9354e..fc67810d 100644
--- a/src/torchjd/aggregation/_graddrop.py
+++ b/src/torchjd/aggregation/_graddrop.py
@@ -27,12 +27,6 @@ class GradDrop(Aggregator):
     """
 
     def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
-        if leak is not None and leak.dim() != 1:
-            raise ValueError(
-                "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
-                f"{leak.shape}`.",
-            )
-
         super().__init__()
         self.f = f
         self.leak = leak
@@ -59,6 +53,19 @@ def forward(self, matrix: Matrix, /) -> Tensor:
 
         return vector
 
+    @property
+    def leak(self) -> Tensor | None:
+        return self._leak
+
+    @leak.setter
+    def leak(self, value: Tensor | None) -> None:
+        if value is not None and value.dim() != 1:
+            raise ValueError(
+                f"leak must be a 1-dimensional tensor. Found leak.shape = {value.shape}.",
+            )
+
+        self._leak = value
+
     def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
         n_rows = matrix.shape[0]
         if self.leak is not None and n_rows != len(self.leak):
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index 5a1edc35..26a075b9 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -41,13 +41,8 @@ class GradVacWeighting(GramianWeighting, Stateful):
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
         super().__init__()
-        if not (0.0 <= beta <= 1.0):
-            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
-        if eps < 0.0:
-            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")
-
-        self._beta = beta
-        self._eps = eps
+        self.beta = beta
+        self.eps = eps
         self._phi_t: Tensor | None = None
         self._state_key: tuple[int, torch.dtype] | None = None
 
diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py
index d48f8918..046c0a46 100644
--- a/src/torchjd/aggregation/_krum.py
+++ b/src/torchjd/aggregation/_krum.py
@@ -20,18 +20,6 @@ class KrumWeighting(GramianWeighting):
 
     def __init__(self, n_byzantine: int, n_selected: int = 1) -> None:
         super().__init__()
-        if n_byzantine < 0:
-            raise ValueError(
-                "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
-                f"{n_byzantine}`.",
-            )
-
-        if n_selected < 1:
-            raise ValueError(
-                "Parameter `n_selected` should be a positive integer. Found `n_selected = "
-                f"{n_selected}`.",
-            )
-
         self.n_byzantine = n_byzantine
         self.n_selected = n_selected
 
@@ -54,6 +42,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
 
         return weights
 
+    @property
+    def n_byzantine(self) -> int:
+        return self._n_byzantine
+
+    @n_byzantine.setter
+    def n_byzantine(self, value: int) -> None:
+        if value < 0:
+            raise ValueError(f"n_byzantine must be non-negative, but got {value}.")
+
+        self._n_byzantine = value
+
+    @property
+    def n_selected(self) -> int:
+        return self._n_selected
+
+    @n_selected.setter
+    def n_selected(self, value: int) -> None:
+        if value < 1:
+            raise ValueError(f"n_selected must be a positive integer, but got {value}.")
+
+        self._n_selected = value
+
     def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
         min_rows = self.n_byzantine + 3
         if gramian.shape[0] < min_rows:
@@ -83,15 +93,29 @@ class Krum(GramianWeightedAggregator):
     gramian_weighting: KrumWeighting
 
     def __init__(self, n_byzantine: int, n_selected: int = 1) -> None:
-        self._n_byzantine = n_byzantine
-        self._n_selected = n_selected
         super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected))
 
+    @property
+    def n_byzantine(self) -> int:
+        return self.gramian_weighting.n_byzantine
+
+    @n_byzantine.setter
+    def n_byzantine(self, value: int) -> None:
+        self.gramian_weighting.n_byzantine = value
+
+    @property
+    def n_selected(self) -> int:
+        return self.gramian_weighting.n_selected
+
+    @n_selected.setter
+    def n_selected(self, value: int) -> None:
+        self.gramian_weighting.n_selected = value
+
     def __repr__(self) -> str:
         return (
-            f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected="
-            f"{self._n_selected})"
+            f"{self.__class__.__name__}(n_byzantine={self.n_byzantine}, n_selected="
+            f"{self.n_selected})"
         )
 
     def __str__(self) -> str:
-        return f"Krum{self._n_byzantine}-{self._n_selected}"
+        return f"Krum{self.n_byzantine}-{self.n_selected}"
diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py
index 575f21a4..33c727aa 100644
--- a/src/torchjd/aggregation/_mgda.py
+++ b/src/torchjd/aggregation/_mgda.py
@@ -49,6 +49,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
                 break
         return alpha
 
+    @property
+    def epsilon(self) -> float:
+        return self._epsilon
+
+    @epsilon.setter
+    def epsilon(self, value: float) -> None:
+        if value <= 0:
+            raise ValueError(f"epsilon must be positive, but got {value}.")
+
+        self._epsilon = value
+
+    @property
+    def max_iters(self) -> int:
+        return self._max_iters
+
+    @max_iters.setter
+    def max_iters(self, value: int) -> None:
+        if value <= 0:
+            raise ValueError(f"max_iters must be a positive integer, but got {value}.")
+
+        self._max_iters = value
+
 
 class MGDA(GramianWeightedAggregator):
     r"""
@@ -67,8 +89,22 @@ class MGDA(GramianWeightedAggregator):
 
     def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None:
         super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters))
-        self._epsilon = epsilon
-        self._max_iters = max_iters
+
+    @property
+    def epsilon(self) -> float:
+        return self.gramian_weighting.epsilon
+
+    @epsilon.setter
+    def epsilon(self, value: float) -> None:
+        self.gramian_weighting.epsilon = value
+
+    @property
+    def max_iters(self) -> int:
+        return self.gramian_weighting.max_iters
+
+    @max_iters.setter
+    def max_iters(self, value: int) -> None:
+        self.gramian_weighting.max_iters = value
 
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})"
+        return f"{self.__class__.__name__}(epsilon={self.epsilon}, max_iters={self.max_iters})"
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py
index 9cd3e7bc..65604d35 100644
--- a/src/torchjd/aggregation/_nash_mtl.py
+++ b/src/torchjd/aggregation/_nash_mtl.py
@@ -8,8 +8,6 @@
 
 check_dependencies_are_installed(["cvxpy", "ecos"])
 
-from typing import cast
-
 import cvxpy as cp
 import numpy as np
 import torch
@@ -28,11 +26,17 @@ class _NashMTLWeighting(MatrixWeighting, Stateful):
     :param n_tasks: The number of tasks, corresponding to the number of rows in the provided
         matrices.
-    :param max_norm: Maximum value of the norm of :math:`J^T w`.
+    :param max_norm: Maximum value of the norm of :math:`J^T w`. A value of ``0`` disables the
+        norm clipping.
     :param update_weights_every: A parameter determining how often the actual weighting should be
         performed. A larger value means that the same weights will be re-used for more calls to
         the weighting.
     :param optim_niter: The number of iterations of the underlying optimization process.
+
+    .. note::
+        Changing any of these parameters after instantiation does not automatically reset the
+        internal state. Call :meth:`reset` if needed (especially after changing ``n_tasks``, which
+        affects the shape of the cached state).
     """
 
     def __init__(
         self,
@@ -55,6 +59,52 @@ def __init__(
         self.step = 0.0
         self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
 
+    @property
+    def n_tasks(self) -> int:
+        return self._n_tasks
+
+    @n_tasks.setter
+    def n_tasks(self, value: int) -> None:
+        if value <= 0:
+            raise ValueError(f"n_tasks must be a positive integer, but got {value}.")
+
+        self._n_tasks = value
+
+    @property
+    def max_norm(self) -> float:
+        return self._max_norm
+
+    @max_norm.setter
+    def max_norm(self, value: float) -> None:
+        if value < 0:
+            raise ValueError(f"max_norm must be non-negative, but got {value}.")
+
+        self._max_norm = value
+
+    @property
+    def update_weights_every(self) -> int:
+        return self._update_weights_every
+
+    @update_weights_every.setter
+    def update_weights_every(self, value: int) -> None:
+        if value <= 0:
+            raise ValueError(
+                f"update_weights_every must be a positive integer, but got {value}.",
+            )
+
+        self._update_weights_every = value
+
+    @property
+    def optim_niter(self) -> int:
+        return self._optim_niter
+
+    @optim_niter.setter
+    def optim_niter(self, value: int) -> None:
+        if value <= 0:
+            raise ValueError(f"optim_niter must be a positive integer, but got {value}.")
+
+        self._optim_niter = value
+
     def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
         return bool(
             (self.alpha_param.value is None)
@@ -156,7 +206,8 @@ class NashMTL(WeightedAggregator, Stateful):
     :param n_tasks: The number of tasks, corresponding to the number of rows in the provided
         matrices.
-    :param max_norm: Maximum value of the norm of :math:`J^T w`.
+    :param max_norm: Maximum value of the norm of :math:`J^T w`. A value of ``0`` disables the
+        norm clipping.
     :param update_weights_every: A parameter determining how often the actual weighting should be
         performed. A larger value means that the same weights will be re-used for more calls to
         the aggregator.
 
@@ -176,6 +227,11 @@
         This aggregator is stateful. Its output will thus depend not only on the input matrix,
         but also on its state. It thus depends on previously seen matrices. It should be reset
         between experiments.
+
+    .. note::
+        Changing any of these parameters after instantiation does not automatically reset the
+        internal state. Call :meth:`reset` if needed (especially after changing ``n_tasks``, which
+        affects the shape of the cached state).
     """
 
     weighting: _NashMTLWeighting
@@ -195,20 +251,48 @@ def __init__(
                 optim_niter=optim_niter,
             ),
         )
-        self._n_tasks = n_tasks
-        self._max_norm = max_norm
-        self._update_weights_every = update_weights_every
-        self._optim_niter = optim_niter
 
         # This prevents considering the computed weights as constant w.r.t. the matrix.
         self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
+    @property
+    def n_tasks(self) -> int:
+        return self.weighting.n_tasks
+
+    @n_tasks.setter
+    def n_tasks(self, value: int) -> None:
+        self.weighting.n_tasks = value
+
+    @property
+    def max_norm(self) -> float:
+        return self.weighting.max_norm
+
+    @max_norm.setter
+    def max_norm(self, value: float) -> None:
+        self.weighting.max_norm = value
+
+    @property
+    def update_weights_every(self) -> int:
+        return self.weighting.update_weights_every
+
+    @update_weights_every.setter
+    def update_weights_every(self, value: int) -> None:
+        self.weighting.update_weights_every = value
+
+    @property
+    def optim_niter(self) -> int:
+        return self.weighting.optim_niter
+
+    @optim_niter.setter
+    def optim_niter(self, value: int) -> None:
+        self.weighting.optim_niter = value
+
     def reset(self) -> None:
         """Resets the internal state of the algorithm."""
-        cast(_NashMTLWeighting, self.weighting).reset()
+        self.weighting.reset()
 
     def __repr__(self) -> str:
         return (
-            f"{self.__class__.__name__}(n_tasks={self._n_tasks}, max_norm={self._max_norm}, "
-            f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter})"
+            f"{self.__class__.__name__}(n_tasks={self.n_tasks}, max_norm={self.max_norm}, "
+            f"update_weights_every={self.update_weights_every}, optim_niter={self.optim_niter})"
         )
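# Sketch of the caveat in the NashMTL note above (editor's illustration, not part of
# the patch): setters do not reset the cached state, so after changing `n_tasks` a
# manual `reset()` is needed. This assumes `reset()` re-initializes the cached
# `prvs_alpha` from the current `n_tasks`.
from torchjd.aggregation import NashMTL

A = NashMTL(n_tasks=2)
A.n_tasks = 4  # validated, but the cached state still has shape (2,)
A.reset()      # rebuild the internal state for the new number of tasks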
diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py
index 8dffe990..013e8e57 100644
--- a/src/torchjd/aggregation/_trimmed_mean.py
+++ b/src/torchjd/aggregation/_trimmed_mean.py
@@ -17,11 +17,6 @@ class TrimmedMean(Aggregator):
 
     def __init__(self, trim_number: int) -> None:
         super().__init__()
-        if trim_number < 0:
-            raise ValueError(
-                "Parameter `trim_number` should be a non-negative integer. Found `trim_number` = "
-                f"{trim_number}`.",
-            )
         self.trim_number = trim_number
 
     def forward(self, matrix: Tensor, /) -> Tensor:
@@ -35,6 +30,17 @@ def forward(self, matrix: Tensor, /) -> Tensor:
         vector = trimmed.mean(dim=0)
         return vector
 
+    @property
+    def trim_number(self) -> int:
+        return self._trim_number
+
+    @trim_number.setter
+    def trim_number(self, value: int) -> None:
+        if value < 0:
+            raise ValueError(f"trim_number must be non-negative, but got {value}.")
+
+        self._trim_number = value
+
     def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
         min_rows = 1 + 2 * self.trim_number
         n_rows = matrix.shape[0]
diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py
index d48e8855..6eacfba9 100644
--- a/tests/unit/aggregation/test_aligned_mtl.py
+++ b/tests/unit/aggregation/test_aligned_mtl.py
@@ -3,7 +3,7 @@
 from torch import Tensor
 from utils.tensors import ones_
 
-from torchjd.aggregation import AlignedMTL
+from torchjd.aggregation import AlignedMTL, ConstantWeighting
 
 from ._asserts import assert_expected_structure, assert_permutation_invariant
 from ._inputs import scaled_matrices, typical_matrices
@@ -43,3 +43,19 @@ def test_invalid_scale_mode() -> None:
     matrix = ones_(3, 4)
     with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
         aggregator(matrix)
+
+
+def test_pref_vector_setter_updates_value() -> None:
+    A = AlignedMTL()
+    new_pref = torch.tensor([1.0, 2.0, 3.0])
+    A.pref_vector = new_pref
+    assert A.pref_vector is new_pref
+    assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
+    assert A.gramian_weighting.weighting.weights is new_pref
+
+
+def test_scale_mode_setter_updates_value() -> None:
+    A = AlignedMTL()
+    A.scale_mode = "rmse"
+    assert A.scale_mode == "rmse"
+    assert A.gramian_weighting.scale_mode == "rmse"
diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py
index c7d18b1f..9128899f 100644
--- a/tests/unit/aggregation/test_cagrad.py
+++ b/tests/unit/aggregation/test_cagrad.py
@@ -7,6 +7,7 @@
 
 try:
     from torchjd.aggregation import CAGrad
+    from torchjd.aggregation._cagrad import CAGradWeighting
 except ImportError:
     import pytest
 
@@ -57,3 +58,41 @@ def test_representations() -> None:
     A = CAGrad(c=0.5, norm_eps=0.0001)
     assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)"
     assert str(A) == "CAGrad0.5"
+
+
+def test_c_setter_updates_value() -> None:
+    A = CAGrad(c=0.5)
+    A.c = 1.25
+    assert A.c == 1.25
+    assert A.gramian_weighting.c == 1.25
+
+
+def test_norm_eps_setter_updates_value() -> None:
+    A = CAGrad(c=0.5)
+    A.norm_eps = 0.25
+    assert A.norm_eps == 0.25
+    assert A.gramian_weighting.norm_eps == 0.25
+
+
+def test_c_setter_rejects_negative() -> None:
+    A = CAGrad(c=0.5)
+    with raises(ValueError, match="c"):
+        A.c = -1e-9
+
+
+def test_norm_eps_setter_rejects_negative() -> None:
+    A = CAGrad(c=0.5)
+    with raises(ValueError, match="norm_eps"):
+        A.norm_eps = -1e-9
+
+
+def test_weighting_c_setter_rejects_negative() -> None:
+    W = CAGradWeighting(c=0.5)
+    with raises(ValueError, match="c"):
+        W.c = -1e-9
+
+
+def test_weighting_norm_eps_setter_rejects_negative() -> None:
+    W = CAGradWeighting(c=0.5)
+    with raises(ValueError, match="norm_eps"):
+        W.norm_eps = -1e-9
diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py
index 2db2ea0f..fe951640 100644
--- a/tests/unit/aggregation/test_config.py
+++ b/tests/unit/aggregation/test_config.py
@@ -3,7 +3,7 @@
 from torch import Tensor
 from utils.tensors import ones_
 
-from torchjd.aggregation import ConFIG
+from torchjd.aggregation import ConFIG, ConstantWeighting
 
 from ._asserts import (
     assert_expected_structure,
@@ -47,3 +47,12 @@ def test_representations() -> None:
     A = ConFIG(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
     assert repr(A) == "ConFIG(pref_vector=tensor([1., 2., 3.]))"
     assert str(A) == "ConFIG([1., 2., 3.])"
+
+
+def test_pref_vector_setter_updates_value() -> None:
+    A = ConFIG()
+    new_pref = torch.tensor([1.0, 2.0, 3.0])
+    A.pref_vector = new_pref
+    assert A.pref_vector is new_pref
+    assert isinstance(A.weighting, ConstantWeighting)
+    assert A.weighting.weights is new_pref
diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py
index 5bd0e71a..34fe8d46 100644
--- a/tests/unit/aggregation/test_dualproj.py
+++ b/tests/unit/aggregation/test_dualproj.py
@@ -1,9 +1,10 @@
 import torch
-from pytest import mark
+from pytest import mark, raises
 from torch import Tensor
 from utils.tensors import ones_
 
-from torchjd.aggregation import DualProj
+from torchjd.aggregation import ConstantWeighting, DualProj
+from torchjd.aggregation._dualproj import DualProjWeighting
 
 from ._asserts import (
     assert_expected_structure,
@@ -63,3 +64,50 @@ def test_representations() -> None:
         "solver='quadprog')"
     )
     assert str(A) == "DualProj([1., 2., 3.])"
+
+
+def test_pref_vector_setter_updates_value() -> None:
+    A = DualProj()
+    new_pref = torch.tensor([1.0, 2.0, 3.0])
+    A.pref_vector = new_pref
+    assert A.pref_vector is new_pref
+    assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
+    assert A.gramian_weighting.weighting.weights is new_pref
+
+
+def test_norm_eps_setter_updates_value() -> None:
+    A = DualProj()
+    A.norm_eps = 0.25
+    assert A.norm_eps == 0.25
+    assert A.gramian_weighting.norm_eps == 0.25
+
+
+def test_reg_eps_setter_updates_value() -> None:
+    A = DualProj()
+    A.reg_eps = 0.25
+    assert A.reg_eps == 0.25
+    assert A.gramian_weighting.reg_eps == 0.25
+
+
+def test_norm_eps_setter_rejects_negative() -> None:
+    A = DualProj()
+    with raises(ValueError, match="norm_eps"):
+        A.norm_eps = -1e-9
+
+
+def test_reg_eps_setter_rejects_negative() -> None:
+    A = DualProj()
+    with raises(ValueError, match="reg_eps"):
+        A.reg_eps = -1e-9
+
+
+def test_weighting_norm_eps_setter_rejects_negative() -> None:
+    W = DualProjWeighting()
+    with raises(ValueError, match="norm_eps"):
+        W.norm_eps = -1e-9
+
+
+def test_weighting_reg_eps_setter_rejects_negative() -> None:
+    W = DualProjWeighting()
+    with raises(ValueError, match="reg_eps"):
+        W.reg_eps = -1e-9
diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py
index 2868dca0..d21ef690 100644
--- a/tests/unit/aggregation/test_graddrop.py
+++ b/tests/unit/aggregation/test_graddrop.py
@@ -83,3 +83,22 @@ def test_representations() -> None:
     A = GradDrop()
     assert re.match(r"GradDrop\(f=<function .*>, leak=None\)", repr(A))
     assert str(A) == "GradDrop"
+
+
+def test_leak_setter_updates_value() -> None:
+    A = GradDrop()
+    new_leak = torch.tensor([0.0, 0.5, 1.0])
+    A.leak = new_leak
+    assert A.leak is new_leak
+
+
+def test_leak_setter_accepts_none() -> None:
+    A = GradDrop(leak=torch.tensor([0.0, 1.0]))
+    A.leak = None
+    assert A.leak is None
+
+
+def test_leak_setter_rejects_non_1d() -> None:
+    A = GradDrop()
+    with raises(ValueError, match="leak"):
+        A.leak = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py
index bde2e8fd..28db6090 100644
--- a/tests/unit/aggregation/test_gradvac.py
+++ b/tests/unit/aggregation/test_gradvac.py
@@ -20,48 +20,6 @@ def test_representations() -> None:
     assert str(A) == "GradVac"
 
 
-def test_beta_out_of_range() -> None:
-    with raises(ValueError, match="beta"):
-        GradVac(beta=-0.1)
-    with raises(ValueError, match="beta"):
-        GradVac(beta=1.1)
-
-
-def test_beta_setter_out_of_range() -> None:
-    A = GradVac()
-    with raises(ValueError, match="beta"):
-        A.beta = -0.1
-    with raises(ValueError, match="beta"):
-        A.beta = 1.1
-
-
-def test_beta_setter_updates_value() -> None:
-    A = GradVac()
-    A.beta = 0.25
-    assert A.beta == 0.25
-
-
-def test_eps_rejects_negative() -> None:
-    with raises(ValueError, match="eps"):
-        GradVac(eps=-1e-9)
-
-
-def test_eps_setter_rejects_negative() -> None:
-    A = GradVac()
-    with raises(ValueError, match="eps"):
-        A.eps = -1e-9
-
-
-def test_eps_can_be_changed_between_steps() -> None:
-    J = tensor_([[1.0, 0.0], [0.0, 1.0]])
-    A = GradVac()
-    A.eps = 1e-6
-    assert A(J).isfinite().all()
-    A.reset()
-    A.eps = 1e-10
-    assert A(J).isfinite().all()
-
-
 def test_zero_rows_returns_zero_vector() -> None:
     out = GradVac()(tensor_([]).reshape(0, 3))
     assert_close(out, tensor_([0.0, 0.0, 0.0]))
@@ -104,18 +62,6 @@ def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None:
     assert_non_differentiable(aggregator, matrix)
 
 
-def test_weighting_beta_out_of_range() -> None:
-    with raises(ValueError, match="beta"):
-        GradVacWeighting(beta=-0.1)
-    with raises(ValueError, match="beta"):
-        GradVacWeighting(beta=1.1)
-
-
-def test_weighting_eps_rejects_negative() -> None:
-    with raises(ValueError, match="eps"):
-        GradVacWeighting(eps=-1e-9)
-
-
 def test_weighting_reset_restores_first_step_behavior() -> None:
     J = randn_((3, 8))
     G = J @ J.T
@@ -144,3 +90,45 @@ def test_aggregator_and_weighting_agree() -> None:
     result = weights @ J
 
     assert_close(result, expected, rtol=1e-4, atol=1e-4)
+
+
+def test_beta_setter_updates_value() -> None:
+    A = GradVac()
+    A.beta = 0.25
+    assert A.beta == 0.25
+    assert A.gramian_weighting.beta == 0.25
+
+
+def test_eps_setter_updates_value() -> None:
+    A = GradVac()
+    A.eps = 1e-6
+    assert A.eps == 1e-6
+    assert A.gramian_weighting.eps == 1e-6
+
+
+def test_beta_setter_rejects_out_of_range() -> None:
+    A = GradVac()
+    with raises(ValueError, match="beta"):
+        A.beta = -0.1
+    with raises(ValueError, match="beta"):
+        A.beta = 1.1
+
+
+def test_eps_setter_rejects_negative() -> None:
+    A = GradVac()
+    with raises(ValueError, match="eps"):
+        A.eps = -1e-9
+
+
+def test_weighting_beta_setter_rejects_out_of_range() -> None:
+    W = GradVacWeighting()
+    with raises(ValueError, match="beta"):
+        W.beta = -0.1
+    with raises(ValueError, match="beta"):
+        W.beta = 1.1
+
+
+def test_weighting_eps_setter_rejects_negative() -> None:
+    W = GradVacWeighting()
+    with raises(ValueError, match="eps"):
+        W.eps = -1e-9
diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py
index 4097f2eb..6270f677 100644
--- a/tests/unit/aggregation/test_krum.py
+++ b/tests/unit/aggregation/test_krum.py
@@ -6,6 +6,7 @@
 from utils.tensors import ones_
 
 from torchjd.aggregation import Krum
+from torchjd.aggregation._krum import KrumWeighting
 
 from ._asserts import assert_expected_structure
 from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
@@ -78,3 +79,43 @@ def test_representations() -> None:
     A = Krum(n_byzantine=1, n_selected=2)
     assert repr(A) == "Krum(n_byzantine=1, n_selected=2)"
     assert str(A) == "Krum1-2"
+
+
+def test_n_byzantine_setter_updates_value() -> None:
+    A = Krum(n_byzantine=1)
+    A.n_byzantine = 3
+    assert A.n_byzantine == 3
+    assert A.gramian_weighting.n_byzantine == 3
+
+
+def test_n_selected_setter_updates_value() -> None:
+    A = Krum(n_byzantine=1)
+    A.n_selected = 3
+    assert A.n_selected == 3
+    assert A.gramian_weighting.n_selected == 3
+
+
+def test_n_byzantine_setter_rejects_negative() -> None:
+    A = Krum(n_byzantine=1)
+    with raises(ValueError, match="n_byzantine"):
+        A.n_byzantine = -1
+
+
+def test_n_selected_setter_rejects_non_positive() -> None:
+    A = Krum(n_byzantine=1)
+    with raises(ValueError, match="n_selected"):
+        A.n_selected = 0
+    with raises(ValueError, match="n_selected"):
+        A.n_selected = -1
+
+
+def test_weighting_n_byzantine_setter_rejects_negative() -> None:
+    W = KrumWeighting(n_byzantine=1)
+    with raises(ValueError, match="n_byzantine"):
+        W.n_byzantine = -1
+
+
+def test_weighting_n_selected_setter_rejects_non_positive() -> None:
+    W = KrumWeighting(n_byzantine=1)
+    with raises(ValueError, match="n_selected"):
+        W.n_selected = 0
diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py
index 5c925b8f..f6c67236 100644
--- a/tests/unit/aggregation/test_mgda.py
+++ b/tests/unit/aggregation/test_mgda.py
@@ -1,4 +1,4 @@
-from pytest import mark
+from pytest import mark, raises
 from torch import Tensor
 from torch.testing import assert_close
 from utils.tensors import ones_, randn_
@@ -70,3 +70,45 @@ def test_representations() -> None:
     A = MGDA(epsilon=0.001, max_iters=100)
     assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)"
     assert str(A) == "MGDA"
+
+
+def test_epsilon_setter_updates_value() -> None:
+    A = MGDA()
+    A.epsilon = 0.25
+    assert A.epsilon == 0.25
+    assert A.gramian_weighting.epsilon == 0.25
+
+
+def test_max_iters_setter_updates_value() -> None:
+    A = MGDA()
+    A.max_iters = 42
+    assert A.max_iters == 42
+    assert A.gramian_weighting.max_iters == 42
+
+
+def test_epsilon_setter_rejects_non_positive() -> None:
+    A = MGDA()
+    with raises(ValueError, match="epsilon"):
+        A.epsilon = 0.0
+    with raises(ValueError, match="epsilon"):
+        A.epsilon = -1e-9
+
+
+def test_max_iters_setter_rejects_non_positive() -> None:
+    A = MGDA()
+    with raises(ValueError, match="max_iters"):
+        A.max_iters = 0
+    with raises(ValueError, match="max_iters"):
+        A.max_iters = -1
+
+
+def test_weighting_epsilon_setter_rejects_non_positive() -> None:
+    W = MGDAWeighting()
+    with raises(ValueError, match="epsilon"):
+        W.epsilon = 0.0
+
+
+def test_weighting_max_iters_setter_rejects_non_positive() -> None:
+    W = MGDAWeighting()
+    with raises(ValueError, match="max_iters"):
+        W.max_iters = 0
diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py
index d82fca41..6dac1f0e 100644
--- a/tests/unit/aggregation/test_nash_mtl.py
+++ b/tests/unit/aggregation/test_nash_mtl.py
@@ -1,10 +1,11 @@
-from pytest import mark
+from pytest import mark, raises
 from torch import Tensor
 from torch.testing import assert_close
 from utils.tensors import ones_, randn_, tensor_
 
 try:
     from torchjd.aggregation import NashMTL
+    from torchjd.aggregation._nash_mtl import _NashMTLWeighting
 except ImportError:
     import pytest
 
@@ -72,3 +73,57 @@ def test_representations() -> None:
     A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)
     assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)"
     assert str(A) == "NashMTL"
+
+
+def test_setters_update_values() -> None:
+    A = NashMTL(n_tasks=2)
+    A.n_tasks = 4
+    A.max_norm = 2.5
+    A.update_weights_every = 3
+    A.optim_niter = 7
+    assert A.n_tasks == 4
+    assert A.max_norm == 2.5
+    assert A.update_weights_every == 3
+    assert A.optim_niter == 7
+    assert A.weighting.n_tasks == 4
+    assert A.weighting.max_norm == 2.5
+    assert A.weighting.update_weights_every == 3
+    assert A.weighting.optim_niter == 7
+
+
+def test_n_tasks_setter_rejects_non_positive() -> None:
+    A = NashMTL(n_tasks=2)
+    with raises(ValueError, match="n_tasks"):
+        A.n_tasks = 0
+    with raises(ValueError, match="n_tasks"):
+        A.n_tasks = -1
+
+
+def test_max_norm_setter_rejects_negative() -> None:
+    A = NashMTL(n_tasks=2)
+    with raises(ValueError, match="max_norm"):
+        A.max_norm = -1e-9
+
+
+def test_update_weights_every_setter_rejects_non_positive() -> None:
+    A = NashMTL(n_tasks=2)
+    with raises(ValueError, match="update_weights_every"):
+        A.update_weights_every = 0
+
+
+def test_optim_niter_setter_rejects_non_positive() -> None:
+    A = NashMTL(n_tasks=2)
+    with raises(ValueError, match="optim_niter"):
+        A.optim_niter = 0
+
+
+def test_weighting_setters_validate() -> None:
+    W = _NashMTLWeighting(n_tasks=2, max_norm=1.0, update_weights_every=1, optim_niter=5)
+    with raises(ValueError, match="n_tasks"):
+        W.n_tasks = 0
+    with raises(ValueError, match="max_norm"):
+        W.max_norm = -1.0
+    with raises(ValueError, match="update_weights_every"):
+        W.update_weights_every = 0
+    with raises(ValueError, match="optim_niter"):
+        W.optim_niter = 0
diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py
index 3a6ccb2b..7217a0e0 100644
--- a/tests/unit/aggregation/test_trimmed_mean.py
+++ b/tests/unit/aggregation/test_trimmed_mean.py
@@ -61,3 +61,15 @@ def test_representations() -> None:
     aggregator = TrimmedMean(trim_number=2)
     assert repr(aggregator) == "TrimmedMean(trim_number=2)"
     assert str(aggregator) == "TM2"
+
+
+def test_trim_number_setter_updates_value() -> None:
+    A = TrimmedMean(trim_number=1)
+    A.trim_number = 3
+    assert A.trim_number == 3
+
+
+def test_trim_number_setter_rejects_negative() -> None:
+    A = TrimmedMean(trim_number=1)
+    with raises(ValueError, match="trim_number"):
+        A.trim_number = -1