From f1bf5b6353be129770f8c583c0eb32e1be28f0a0 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:17:03 +0200 Subject: [PATCH 01/15] feat(aggregation): Add getters and setters to DualProj parameters Expose pref_vector, norm_eps, and reg_eps as properties on DualProj and DualProjWeighting, mirroring the UPGrad pattern. The norm_eps and reg_eps setters validate that the new value is non-negative. --- src/torchjd/aggregation/_dualproj.py | 67 ++++++++++++++++++++++--- tests/unit/aggregation/test_dualproj.py | 50 +++++++++++++++++- 2 files changed, 107 insertions(+), 10 deletions(-) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 087e8805f..acb87d2fb 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -33,8 +33,7 @@ def __init__( solver: SUPPORTED_SOLVER = "quadprog", ) -> None: super().__init__() - self._pref_vector = pref_vector - self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) + self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps self.solver: SUPPORTED_SOLVER = solver @@ -45,6 +44,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: w = project_weights(u, G, self.solver) return w + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.weighting = pref_vector_to_weighting(value, default=MeanWeighting()) + self._pref_vector = value + + @property + def norm_eps(self) -> float: + return self._norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + if value < 0: + raise ValueError(f"norm_eps must be non-negative, but got {value}.") + + self._norm_eps = value + + @property + def reg_eps(self) -> float: + return self._reg_eps + + @reg_eps.setter + def reg_eps(self, value: float) -> None: + if value < 0: + raise ValueError(f"reg_eps must be non-negative, but got {value}.") + + self._reg_eps = value + class DualProj(GramianWeightedAggregator): r""" @@ -72,9 +102,6 @@ def __init__( reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", ) -> None: - self._pref_vector = pref_vector - self._norm_eps = norm_eps - self._reg_eps = reg_eps self._solver: SUPPORTED_SOLVER = solver super().__init__( @@ -84,11 +111,35 @@ def __init__( # This prevents considering the computed weights as constant w.r.t. the matrix. 
self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def pref_vector(self) -> Tensor | None: + return self.gramian_weighting.pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.gramian_weighting.pref_vector = value + + @property + def norm_eps(self) -> float: + return self.gramian_weighting.norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + self.gramian_weighting.norm_eps = value + + @property + def reg_eps(self) -> float: + return self.gramian_weighting.reg_eps + + @reg_eps.setter + def reg_eps(self, value: float) -> None: + self.gramian_weighting.reg_eps = value + def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" ) def __str__(self) -> str: - return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}" + return f"DualProj{pref_vector_to_str_suffix(self.pref_vector)}" diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 5bd0e71af..dbe5d4c10 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -1,9 +1,10 @@ import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor from utils.tensors import ones_ -from torchjd.aggregation import DualProj +from torchjd.aggregation import ConstantWeighting, DualProj +from torchjd.aggregation._dualproj import DualProjWeighting from ._asserts import ( assert_expected_structure, @@ -63,3 +64,48 @@ def test_representations() -> None: "solver='quadprog')" ) assert str(A) == "DualProj([1., 2., 3.])" + + +def test_pref_vector_setter_updates_value() -> None: + A = DualProj() + new_pref = torch.tensor([1.0, 2.0, 3.0]) + A.pref_vector = new_pref + assert A.pref_vector is new_pref + assert isinstance(A.gramian_weighting.weighting, ConstantWeighting) + assert A.gramian_weighting.weighting.weights is new_pref + + +def test_norm_eps_setter_updates_value() -> None: + A = DualProj() + A.norm_eps = 0.25 + assert A.norm_eps == 0.25 + + +def test_reg_eps_setter_updates_value() -> None: + A = DualProj() + A.reg_eps = 0.25 + assert A.reg_eps == 0.25 + + +def test_norm_eps_setter_rejects_negative() -> None: + A = DualProj() + with raises(ValueError, match="norm_eps"): + A.norm_eps = -1e-9 + + +def test_reg_eps_setter_rejects_negative() -> None: + A = DualProj() + with raises(ValueError, match="reg_eps"): + A.reg_eps = -1e-9 + + +def test_weighting_norm_eps_setter_rejects_negative() -> None: + W = DualProjWeighting() + with raises(ValueError, match="norm_eps"): + W.norm_eps = -1e-9 + + +def test_weighting_reg_eps_setter_rejects_negative() -> None: + W = DualProjWeighting() + with raises(ValueError, match="reg_eps"): + W.reg_eps = -1e-9 From ea5fefd17851ecfc814311682db4134b66ca60a3 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:17:11 +0200 Subject: [PATCH 02/15] feat(aggregation): Add getters and setters to AlignedMTL parameters Expose pref_vector and scale_mode as properties on AlignedMTL and AlignedMTLWeighting, mirroring the UPGrad pattern. 
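
A quick usage sketch (mirroring the new tests; the "rmse" value is
arbitrary):

    aggregator = AlignedMTL()
    aggregator.scale_mode = "rmse"  # delegated to the underlying AlignedMTLWeighting
    assert aggregator.scale_mode == "rmse"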
--- src/torchjd/aggregation/_aligned_mtl.py | 40 +++++++++++++++++----- tests/unit/aggregation/test_aligned_mtl.py | 17 ++++++++- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 07574b1b8..b0fbe3b42 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -35,17 +35,25 @@ def __init__( scale_mode: SUPPORTED_SCALE_MODE = "min", ) -> None: super().__init__() - self._pref_vector = pref_vector - self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode - self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) + self.pref_vector = pref_vector + self.scale_mode: SUPPORTED_SCALE_MODE = scale_mode def forward(self, gramian: PSDMatrix, /) -> Tensor: w = self.weighting(gramian) - B = self._compute_balance_transformation(gramian, self._scale_mode) + B = self._compute_balance_transformation(gramian, self.scale_mode) alpha = B @ w return alpha + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.weighting = pref_vector_to_weighting(value, default=MeanWeighting()) + self._pref_vector = value + @staticmethod def _compute_balance_transformation( M: Tensor, @@ -103,15 +111,29 @@ def __init__( pref_vector: Tensor | None = None, scale_mode: SUPPORTED_SCALE_MODE = "min", ) -> None: - self._pref_vector = pref_vector - self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) + @property + def pref_vector(self) -> Tensor | None: + return self.gramian_weighting.pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.gramian_weighting.pref_vector = value + + @property + def scale_mode(self) -> SUPPORTED_SCALE_MODE: + return self.gramian_weighting.scale_mode + + @scale_mode.setter + def scale_mode(self, value: SUPPORTED_SCALE_MODE) -> None: + self.gramian_weighting.scale_mode = value + def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " - f"scale_mode={repr(self._scale_mode)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, " + f"scale_mode={repr(self.scale_mode)})" ) def __str__(self) -> str: - return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" + return f"AlignedMTL{pref_vector_to_str_suffix(self.pref_vector)}" diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index d48e88556..303064637 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -3,7 +3,7 @@ from torch import Tensor from utils.tensors import ones_ -from torchjd.aggregation import AlignedMTL +from torchjd.aggregation import AlignedMTL, ConstantWeighting from ._asserts import assert_expected_structure, assert_permutation_invariant from ._inputs import scaled_matrices, typical_matrices @@ -43,3 +43,18 @@ def test_invalid_scale_mode() -> None: matrix = ones_(3, 4) with raises(ValueError, match=r"Invalid scale_mode=.*Expected"): aggregator(matrix) + + +def test_pref_vector_setter_updates_value() -> None: + A = AlignedMTL() + new_pref = torch.tensor([1.0, 2.0, 3.0]) + A.pref_vector = new_pref + assert A.pref_vector is new_pref + assert isinstance(A.gramian_weighting.weighting, ConstantWeighting) + assert A.gramian_weighting.weighting.weights is new_pref + + +def 
test_scale_mode_setter_updates_value() -> None: + A = AlignedMTL() + A.scale_mode = "rmse" + assert A.scale_mode == "rmse" From 83ebb2459a5ff7ecb89495a60d8207841465a3ee Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:17:13 +0200 Subject: [PATCH 03/15] feat(aggregation): Add getter and setter to ConFIG pref_vector Expose pref_vector as a property on ConFIG. The setter rebuilds the internal weighting via pref_vector_to_weighting to keep state in sync. --- src/torchjd/aggregation/_config.py | 16 ++++++++++++---- tests/unit/aggregation/test_config.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 261f3a64a..f19c4023a 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -29,8 +29,7 @@ class ConFIG(Aggregator): def __init__(self, pref_vector: Tensor | None = None) -> None: super().__init__() - self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting()) - self._pref_vector = pref_vector + self.pref_vector = pref_vector # This prevents computing gradients that can be very wrong. self.register_full_backward_pre_hook(raise_non_differentiable_error) @@ -46,8 +45,17 @@ def forward(self, matrix: Matrix, /) -> Tensor: return length * unit_target_vector + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.weighting = pref_vector_to_weighting(value, default=SumWeighting()) + self._pref_vector = value + def __repr__(self) -> str: - return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" + return f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)})" def __str__(self) -> str: - return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}" + return f"ConFIG{pref_vector_to_str_suffix(self.pref_vector)}" diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 2db2ea0fa..fe9516404 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -3,7 +3,7 @@ from torch import Tensor from utils.tensors import ones_ -from torchjd.aggregation import ConFIG +from torchjd.aggregation import ConFIG, ConstantWeighting from ._asserts import ( assert_expected_structure, @@ -47,3 +47,12 @@ def test_representations() -> None: A = ConFIG(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) assert repr(A) == "ConFIG(pref_vector=tensor([1., 2., 3.]))" assert str(A) == "ConFIG([1., 2., 3.])" + + +def test_pref_vector_setter_updates_value() -> None: + A = ConFIG() + new_pref = torch.tensor([1.0, 2.0, 3.0]) + A.pref_vector = new_pref + assert A.pref_vector is new_pref + assert isinstance(A.weighting, ConstantWeighting) + assert A.weighting.weights is new_pref From cdb7ce0fa174eb61eff0be0348b05c366e449af8 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:20:12 +0200 Subject: [PATCH 04/15] feat(aggregation): Add getters and setters to CAGrad parameters Expose c and norm_eps as properties on CAGrad and CAGradWeighting. The setters validate that the new value is non-negative. 
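
A quick usage sketch (mirroring the new tests; values are arbitrary):

    aggregator = CAGrad(c=0.5)
    aggregator.c = 1.25  # delegated to the underlying CAGradWeighting
    assert aggregator.gramian_weighting.c == 1.25
    # Negative values are rejected: `aggregator.c = -1.0` raises ValueError.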
--- src/torchjd/aggregation/_cagrad.py | 48 ++++++++++++++++++++++----- tests/unit/aggregation/test_cagrad.py | 37 +++++++++++++++++++++ 2 files changed, 77 insertions(+), 8 deletions(-) diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 88ca66e06..40ce85157 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -37,10 +37,6 @@ class CAGradWeighting(GramianWeighting): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__() - - if c < 0.0: - raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.") - self.c = c self.norm_eps = norm_eps @@ -73,6 +69,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights + @property + def c(self) -> float: + return self._c + + @c.setter + def c(self, value: float) -> None: + if value < 0: + raise ValueError(f"c must be non-negative, but got {value}.") + + self._c = value + + @property + def norm_eps(self) -> float: + return self._norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + if value < 0: + raise ValueError(f"norm_eps must be non-negative, but got {value}.") + + self._norm_eps = value + class CAGrad(GramianWeightedAggregator): """ @@ -94,15 +112,29 @@ class CAGrad(GramianWeightedAggregator): def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) - self._c = c - self._norm_eps = norm_eps # This prevents considering the computed weights as constant w.r.t. the matrix. self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def c(self) -> float: + return self.gramian_weighting.c + + @c.setter + def c(self, value: float) -> None: + self.gramian_weighting.c = value + + @property + def norm_eps(self) -> float: + return self.gramian_weighting.norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + self.gramian_weighting.norm_eps = value + def __repr__(self) -> str: - return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})" + return f"{self.__class__.__name__}(c={self.c}, norm_eps={self.norm_eps})" def __str__(self) -> str: - c_str = str(self._c).rstrip("0") + c_str = str(self.c).rstrip("0") return f"CAGrad{c_str}" diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index c7d18b1f2..ee479746f 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -7,6 +7,7 @@ try: from torchjd.aggregation import CAGrad + from torchjd.aggregation._cagrad import CAGradWeighting except ImportError: import pytest @@ -57,3 +58,39 @@ def test_representations() -> None: A = CAGrad(c=0.5, norm_eps=0.0001) assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" assert str(A) == "CAGrad0.5" + + +def test_c_setter_updates_value() -> None: + A = CAGrad(c=0.5) + A.c = 1.25 + assert A.c == 1.25 + + +def test_norm_eps_setter_updates_value() -> None: + A = CAGrad(c=0.5) + A.norm_eps = 0.25 + assert A.norm_eps == 0.25 + + +def test_c_setter_rejects_negative() -> None: + A = CAGrad(c=0.5) + with raises(ValueError, match="c"): + A.c = -1e-9 + + +def test_norm_eps_setter_rejects_negative() -> None: + A = CAGrad(c=0.5) + with raises(ValueError, match="norm_eps"): + A.norm_eps = -1e-9 + + +def test_weighting_c_setter_rejects_negative() -> None: + W = CAGradWeighting(c=0.5) + with raises(ValueError, match="c"): + W.c = -1e-9 + + +def test_weighting_norm_eps_setter_rejects_negative() -> None: + W = CAGradWeighting(c=0.5) + with 
raises(ValueError, match="norm_eps"): + W.norm_eps = -1e-9 From 5aaaaf0ccc9c65f9680242ebd8e5e84f07d2fe82 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:21:15 +0200 Subject: [PATCH 05/15] feat(aggregation): Add getters and setters to Krum parameters Expose n_byzantine and n_selected as properties on Krum and KrumWeighting. The setters enforce n_byzantine >= 0 and n_selected >= 1. --- src/torchjd/aggregation/_krum.py | 58 ++++++++++++++++++++--------- tests/unit/aggregation/test_krum.py | 39 +++++++++++++++++++ 2 files changed, 80 insertions(+), 17 deletions(-) diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index d48f8918a..046c0a461 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -20,18 +20,6 @@ class KrumWeighting(GramianWeighting): def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: super().__init__() - if n_byzantine < 0: - raise ValueError( - "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = " - f"{n_byzantine}`.", - ) - - if n_selected < 1: - raise ValueError( - "Parameter `n_selected` should be a positive integer. Found `n_selected = " - f"{n_selected}`.", - ) - self.n_byzantine = n_byzantine self.n_selected = n_selected @@ -54,6 +42,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights + @property + def n_byzantine(self) -> int: + return self._n_byzantine + + @n_byzantine.setter + def n_byzantine(self, value: int) -> None: + if value < 0: + raise ValueError(f"n_byzantine must be non-negative, but got {value}.") + + self._n_byzantine = value + + @property + def n_selected(self) -> int: + return self._n_selected + + @n_selected.setter + def n_selected(self, value: int) -> None: + if value < 1: + raise ValueError(f"n_selected must be a positive integer, but got {value}.") + + self._n_selected = value + def _check_matrix_shape(self, gramian: PSDMatrix) -> None: min_rows = self.n_byzantine + 3 if gramian.shape[0] < min_rows: @@ -83,15 +93,29 @@ class Krum(GramianWeightedAggregator): gramian_weighting: KrumWeighting def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: - self._n_byzantine = n_byzantine - self._n_selected = n_selected super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) + @property + def n_byzantine(self) -> int: + return self.gramian_weighting.n_byzantine + + @n_byzantine.setter + def n_byzantine(self, value: int) -> None: + self.gramian_weighting.n_byzantine = value + + @property + def n_selected(self) -> int: + return self.gramian_weighting.n_selected + + @n_selected.setter + def n_selected(self, value: int) -> None: + self.gramian_weighting.n_selected = value + def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(n_byzantine={self._n_byzantine}, n_selected=" - f"{self._n_selected})" + f"{self.__class__.__name__}(n_byzantine={self.n_byzantine}, n_selected=" + f"{self.n_selected})" ) def __str__(self) -> str: - return f"Krum{self._n_byzantine}-{self._n_selected}" + return f"Krum{self.n_byzantine}-{self.n_selected}" diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 4097f2ebe..5a86601bf 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -6,6 +6,7 @@ from utils.tensors import ones_ from torchjd.aggregation import Krum +from torchjd.aggregation._krum import KrumWeighting from ._asserts import assert_expected_structure from ._inputs import scaled_matrices_2_plus_rows, 
typical_matrices_2_plus_rows @@ -78,3 +79,41 @@ def test_representations() -> None: A = Krum(n_byzantine=1, n_selected=2) assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" assert str(A) == "Krum1-2" + + +def test_n_byzantine_setter_updates_value() -> None: + A = Krum(n_byzantine=1) + A.n_byzantine = 3 + assert A.n_byzantine == 3 + + +def test_n_selected_setter_updates_value() -> None: + A = Krum(n_byzantine=1) + A.n_selected = 3 + assert A.n_selected == 3 + + +def test_n_byzantine_setter_rejects_negative() -> None: + A = Krum(n_byzantine=1) + with raises(ValueError, match="n_byzantine"): + A.n_byzantine = -1 + + +def test_n_selected_setter_rejects_non_positive() -> None: + A = Krum(n_byzantine=1) + with raises(ValueError, match="n_selected"): + A.n_selected = 0 + with raises(ValueError, match="n_selected"): + A.n_selected = -1 + + +def test_weighting_n_byzantine_setter_rejects_negative() -> None: + W = KrumWeighting(n_byzantine=1) + with raises(ValueError, match="n_byzantine"): + W.n_byzantine = -1 + + +def test_weighting_n_selected_setter_rejects_non_positive() -> None: + W = KrumWeighting(n_byzantine=1) + with raises(ValueError, match="n_selected"): + W.n_selected = 0 From e2abe505d402761179d500a857469715e5088ddf Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 21:23:02 +0200 Subject: [PATCH 06/15] feat(aggregation): Add getters and setters to MGDA parameters Expose epsilon and max_iters as properties on MGDA and MGDAWeighting. The setters enforce that both parameters are strictly positive. --- src/torchjd/aggregation/_mgda.py | 42 ++++++++++++++++++++++++++--- tests/unit/aggregation/test_mgda.py | 42 ++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 575f21a48..33c727aa1 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -49,6 +49,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: break return alpha + @property + def epsilon(self) -> float: + return self._epsilon + + @epsilon.setter + def epsilon(self, value: float) -> None: + if value <= 0: + raise ValueError(f"epsilon must be positive, but got {value}.") + + self._epsilon = value + + @property + def max_iters(self) -> int: + return self._max_iters + + @max_iters.setter + def max_iters(self, value: int) -> None: + if value <= 0: + raise ValueError(f"max_iters must be a positive integer, but got {value}.") + + self._max_iters = value + class MGDA(GramianWeightedAggregator): r""" @@ -67,8 +89,22 @@ class MGDA(GramianWeightedAggregator): def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) - self._epsilon = epsilon - self._max_iters = max_iters + + @property + def epsilon(self) -> float: + return self.gramian_weighting.epsilon + + @epsilon.setter + def epsilon(self, value: float) -> None: + self.gramian_weighting.epsilon = value + + @property + def max_iters(self) -> int: + return self.gramian_weighting.max_iters + + @max_iters.setter + def max_iters(self, value: int) -> None: + self.gramian_weighting.max_iters = value def __repr__(self) -> str: - return f"{self.__class__.__name__}(epsilon={self._epsilon}, max_iters={self._max_iters})" + return f"{self.__class__.__name__}(epsilon={self.epsilon}, max_iters={self.max_iters})" diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 5c925b8fe..b9d6873f9 100644 --- 
a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -1,4 +1,4 @@ -from pytest import mark +from pytest import mark, raises from torch import Tensor from torch.testing import assert_close from utils.tensors import ones_, randn_ @@ -70,3 +70,43 @@ def test_representations() -> None: A = MGDA(epsilon=0.001, max_iters=100) assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)" assert str(A) == "MGDA" + + +def test_epsilon_setter_updates_value() -> None: + A = MGDA() + A.epsilon = 0.25 + assert A.epsilon == 0.25 + + +def test_max_iters_setter_updates_value() -> None: + A = MGDA() + A.max_iters = 42 + assert A.max_iters == 42 + + +def test_epsilon_setter_rejects_non_positive() -> None: + A = MGDA() + with raises(ValueError, match="epsilon"): + A.epsilon = 0.0 + with raises(ValueError, match="epsilon"): + A.epsilon = -1e-9 + + +def test_max_iters_setter_rejects_non_positive() -> None: + A = MGDA() + with raises(ValueError, match="max_iters"): + A.max_iters = 0 + with raises(ValueError, match="max_iters"): + A.max_iters = -1 + + +def test_weighting_epsilon_setter_rejects_non_positive() -> None: + W = MGDAWeighting() + with raises(ValueError, match="epsilon"): + W.epsilon = 0.0 + + +def test_weighting_max_iters_setter_rejects_non_positive() -> None: + W = MGDAWeighting() + with raises(ValueError, match="max_iters"): + W.max_iters = 0 From 41e02bd26d6a4f8c7a18ca0bb29805d0fcca728f Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 22:30:40 +0200 Subject: [PATCH 07/15] feat(aggregation): Add getters and setters to NashMTL parameters Expose n_tasks, max_norm, update_weights_every, and optim_niter as properties on NashMTL and _NashMTLWeighting. Validation: n_tasks/update_weights_every/optim_niter > 0, max_norm >= 0 (0 disables norm clipping, matching existing forward logic). Setters do not automatically reset the internal state; users must call reset() if the state needs rebuilding (especially after changing n_tasks). Documented in both class docstrings. --- src/torchjd/aggregation/_nash_mtl.py | 102 ++++++++++++++++++++++-- tests/unit/aggregation/test_nash_mtl.py | 53 +++++++++++- 2 files changed, 146 insertions(+), 9 deletions(-) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 9cd3e7bc8..952fcd0e8 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -28,11 +28,17 @@ class _NashMTLWeighting(MatrixWeighting, Stateful): :param n_tasks: The number of tasks, corresponding to the number of rows in the provided matrices. - :param max_norm: Maximum value of the norm of :math:`J^T w`. + :param max_norm: Maximum value of the norm of :math:`J^T w`. A value of ``0`` disables the + norm clipping. :param update_weights_every: A parameter determining how often the actual weighting should be performed. A larger value means that the same weights will be re-used for more calls to the weighting. :param optim_niter: The number of iterations of the underlying optimization process. + + .. note:: + Changing any of these parameters after instantiation does not automatically reset the + internal state. Call :meth:`reset` if needed (especially after changing ``n_tasks``, which + affects the shape of the cached state). 
""" def __init__( @@ -55,6 +61,52 @@ def __init__( self.step = 0.0 self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) + @property + def n_tasks(self) -> int: + return self._n_tasks + + @n_tasks.setter + def n_tasks(self, value: int) -> None: + if value <= 0: + raise ValueError(f"n_tasks must be a positive integer, but got {value}.") + + self._n_tasks = value + + @property + def max_norm(self) -> float: + return self._max_norm + + @max_norm.setter + def max_norm(self, value: float) -> None: + if value < 0: + raise ValueError(f"max_norm must be non-negative, but got {value}.") + + self._max_norm = value + + @property + def update_weights_every(self) -> int: + return self._update_weights_every + + @update_weights_every.setter + def update_weights_every(self, value: int) -> None: + if value <= 0: + raise ValueError( + f"update_weights_every must be a positive integer, but got {value}.", + ) + + self._update_weights_every = value + + @property + def optim_niter(self) -> int: + return self._optim_niter + + @optim_niter.setter + def optim_niter(self, value: int) -> None: + if value <= 0: + raise ValueError(f"optim_niter must be a positive integer, but got {value}.") + + self._optim_niter = value + def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool: return bool( (self.alpha_param.value is None) @@ -156,7 +208,8 @@ class NashMTL(WeightedAggregator, Stateful): :param n_tasks: The number of tasks, corresponding to the number of rows in the provided matrices. - :param max_norm: Maximum value of the norm of :math:`J^T w`. + :param max_norm: Maximum value of the norm of :math:`J^T w`. A value of ``0`` disables the + norm clipping. :param update_weights_every: A parameter determining how often the actual weighting should be performed. A larger value means that the same weights will be re-used for more calls to the aggregator. @@ -176,6 +229,11 @@ class NashMTL(WeightedAggregator, Stateful): This aggregator is stateful. Its output will thus depend not only on the input matrix, but also on its state. It thus depends on previously seen matrices. It should be reset between experiments. + + .. note:: + Changing any of these parameters after instantiation does not automatically reset the + internal state. Call :meth:`reset` if needed (especially after changing ``n_tasks``, which + affects the shape of the cached state). """ weighting: _NashMTLWeighting @@ -195,20 +253,48 @@ def __init__( optim_niter=optim_niter, ), ) - self._n_tasks = n_tasks - self._max_norm = max_norm - self._update_weights_every = update_weights_every - self._optim_niter = optim_niter # This prevents considering the computed weights as constant w.r.t. the matrix. 
self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def n_tasks(self) -> int: + return cast(_NashMTLWeighting, self.weighting).n_tasks + + @n_tasks.setter + def n_tasks(self, value: int) -> None: + cast(_NashMTLWeighting, self.weighting).n_tasks = value + + @property + def max_norm(self) -> float: + return cast(_NashMTLWeighting, self.weighting).max_norm + + @max_norm.setter + def max_norm(self, value: float) -> None: + cast(_NashMTLWeighting, self.weighting).max_norm = value + + @property + def update_weights_every(self) -> int: + return cast(_NashMTLWeighting, self.weighting).update_weights_every + + @update_weights_every.setter + def update_weights_every(self, value: int) -> None: + cast(_NashMTLWeighting, self.weighting).update_weights_every = value + + @property + def optim_niter(self) -> int: + return cast(_NashMTLWeighting, self.weighting).optim_niter + + @optim_niter.setter + def optim_niter(self, value: int) -> None: + cast(_NashMTLWeighting, self.weighting).optim_niter = value + def reset(self) -> None: """Resets the internal state of the algorithm.""" cast(_NashMTLWeighting, self.weighting).reset() def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(n_tasks={self._n_tasks}, max_norm={self._max_norm}, " - f"update_weights_every={self._update_weights_every}, optim_niter={self._optim_niter})" + f"{self.__class__.__name__}(n_tasks={self.n_tasks}, max_norm={self.max_norm}, " + f"update_weights_every={self.update_weights_every}, optim_niter={self.optim_niter})" ) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index d82fca414..7d2cb0f9e 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,10 +1,11 @@ -from pytest import mark +from pytest import mark, raises from torch import Tensor from torch.testing import assert_close from utils.tensors import ones_, randn_, tensor_ try: from torchjd.aggregation import NashMTL + from torchjd.aggregation._nash_mtl import _NashMTLWeighting except ImportError: import pytest @@ -72,3 +73,53 @@ def test_representations() -> None: A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5) assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)" assert str(A) == "NashMTL" + + +def test_setters_update_values() -> None: + A = NashMTL(n_tasks=2) + A.n_tasks = 4 + A.max_norm = 2.5 + A.update_weights_every = 3 + A.optim_niter = 7 + assert A.n_tasks == 4 + assert A.max_norm == 2.5 + assert A.update_weights_every == 3 + assert A.optim_niter == 7 + + +def test_n_tasks_setter_rejects_non_positive() -> None: + A = NashMTL(n_tasks=2) + with raises(ValueError, match="n_tasks"): + A.n_tasks = 0 + with raises(ValueError, match="n_tasks"): + A.n_tasks = -1 + + +def test_max_norm_setter_rejects_negative() -> None: + A = NashMTL(n_tasks=2) + with raises(ValueError, match="max_norm"): + A.max_norm = -1e-9 + + +def test_update_weights_every_setter_rejects_non_positive() -> None: + A = NashMTL(n_tasks=2) + with raises(ValueError, match="update_weights_every"): + A.update_weights_every = 0 + + +def test_optim_niter_setter_rejects_non_positive() -> None: + A = NashMTL(n_tasks=2) + with raises(ValueError, match="optim_niter"): + A.optim_niter = 0 + + +def test_weighting_setters_validate() -> None: + W = _NashMTLWeighting(n_tasks=2, max_norm=1.0, update_weights_every=1, optim_niter=5) + with raises(ValueError, match="n_tasks"): + W.n_tasks = 0 + with raises(ValueError, match="max_norm"): 
+ W.max_norm = -1.0 + with raises(ValueError, match="update_weights_every"): + W.update_weights_every = 0 + with raises(ValueError, match="optim_niter"): + W.optim_niter = 0 From 5412b7cc3c5f99d3b1b3243ed8f222394e77b826 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 22:31:18 +0200 Subject: [PATCH 08/15] feat(aggregation): Add getter and setter to TrimmedMean trim_number Expose trim_number as a property on TrimmedMean. The setter preserves the existing non-negative validation. --- src/torchjd/aggregation/_trimmed_mean.py | 16 +++++++++++----- tests/unit/aggregation/test_trimmed_mean.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 8dffe990d..013e8e579 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -17,11 +17,6 @@ class TrimmedMean(Aggregator): def __init__(self, trim_number: int) -> None: super().__init__() - if trim_number < 0: - raise ValueError( - "Parameter `trim_number` should be a non-negative integer. Found `trim_number` = " - f"{trim_number}`.", - ) self.trim_number = trim_number def forward(self, matrix: Tensor, /) -> Tensor: @@ -35,6 +30,17 @@ def forward(self, matrix: Tensor, /) -> Tensor: vector = trimmed.mean(dim=0) return vector + @property + def trim_number(self) -> int: + return self._trim_number + + @trim_number.setter + def trim_number(self, value: int) -> None: + if value < 0: + raise ValueError(f"trim_number must be non-negative, but got {value}.") + + self._trim_number = value + def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: min_rows = 1 + 2 * self.trim_number n_rows = matrix.shape[0] diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index 3a6ccb2bc..7217a0e08 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -61,3 +61,15 @@ def test_representations() -> None: aggregator = TrimmedMean(trim_number=2) assert repr(aggregator) == "TrimmedMean(trim_number=2)" assert str(aggregator) == "TM2" + + +def test_trim_number_setter_updates_value() -> None: + A = TrimmedMean(trim_number=1) + A.trim_number = 3 + assert A.trim_number == 3 + + +def test_trim_number_setter_rejects_negative() -> None: + A = TrimmedMean(trim_number=1) + with raises(ValueError, match="trim_number"): + A.trim_number = -1 From fdf5f5b805b1291522ec0d36363f60505222ed42 Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 22:32:19 +0200 Subject: [PATCH 09/15] feat(aggregation): Add getter and setter to GradDrop leak Expose leak as a property on GradDrop. The setter preserves the existing 1D-tensor validation. f is kept as a plain public attribute (no validation needed), matching the pattern used for UPGrad's solver. --- src/torchjd/aggregation/_graddrop.py | 19 +++++++++++++------ tests/unit/aggregation/test_graddrop.py | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 61c9354ec..fc67810da 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -27,12 +27,6 @@ class GradDrop(Aggregator): """ def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None: - if leak is not None and leak.dim() != 1: - raise ValueError( - "Parameter `leak` should be a 1-dimensional tensor. 
Found `leak.shape = " - f"{leak.shape}`.", - ) - super().__init__() self.f = f self.leak = leak @@ -59,6 +53,19 @@ def forward(self, matrix: Matrix, /) -> Tensor: return vector + @property + def leak(self) -> Tensor | None: + return self._leak + + @leak.setter + def leak(self, value: Tensor | None) -> None: + if value is not None and value.dim() != 1: + raise ValueError( + f"leak must be a 1-dimensional tensor. Found leak.shape = {value.shape}.", + ) + + self._leak = value + def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: n_rows = matrix.shape[0] if self.leak is not None and n_rows != len(self.leak): diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 2868dca0d..d21ef690f 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -83,3 +83,22 @@ def test_representations() -> None: A = GradDrop() assert re.match(r"GradDrop\(f=, leak=None\)", repr(A)) assert str(A) == "GradDrop" + + +def test_leak_setter_updates_value() -> None: + A = GradDrop() + new_leak = torch.tensor([0.0, 0.5, 1.0]) + A.leak = new_leak + assert A.leak is new_leak + + +def test_leak_setter_accepts_none() -> None: + A = GradDrop(leak=torch.tensor([0.0, 1.0])) + A.leak = None + assert A.leak is None + + +def test_leak_setter_rejects_non_1d() -> None: + A = GradDrop() + with raises(ValueError, match="leak"): + A.leak = torch.tensor([[0.0, 1.0], [1.0, 0.0]]) From 95cfe885e21bcbeca79af3fd31fd74eafdf33a8f Mon Sep 17 00:00:00 2001 From: mattbuot Date: Mon, 20 Apr 2026 22:33:13 +0200 Subject: [PATCH 10/15] docs: Add changelog entry for aggregator getter/setter rollout Covers AlignedMTL, CAGrad, ConFIG, DualProj, GradDrop, Krum, MGDA, NashMTL, and TrimmedMean. --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46aa6333d..74be3971a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,13 @@ changelog does not include internal changes that do not affect the user. - Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is non-negative. +- Added getters and setters for the constructor parameters of `AlignedMTL` / `AlignedMTLWeighting` + (`pref_vector`, `scale_mode`), `CAGrad` / `CAGradWeighting` (`c`, `norm_eps`), `ConFIG` + (`pref_vector`), `DualProj` / `DualProjWeighting` (`pref_vector`, `norm_eps`, `reg_eps`), + `GradDrop` (`leak`), `Krum` / `KrumWeighting` (`n_byzantine`, `n_selected`), `MGDA` / + `MGDAWeighting` (`epsilon`, `max_iters`), `NashMTL` / `_NashMTLWeighting` (`n_tasks`, `max_norm`, + `update_weights_every`, `optim_niter`), and `TrimmedMean` (`trim_number`). Setters validate their + inputs matching the existing constructor checks (e.g. non-negative or strictly-positive). 
## [0.10.0] - 2026-04-16 From 089f8be3832fba16e65ac9f9c3d05021133d14fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 21 Apr 2026 13:33:13 +0200 Subject: [PATCH 11/15] Remove unnecessary cast in NashMTL --- src/torchjd/aggregation/_nash_mtl.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 952fcd0e8..65604d358 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -8,8 +8,6 @@ check_dependencies_are_installed(["cvxpy", "ecos"]) -from typing import cast - import cvxpy as cp import numpy as np import torch @@ -259,39 +257,39 @@ def __init__( @property def n_tasks(self) -> int: - return cast(_NashMTLWeighting, self.weighting).n_tasks + return self.weighting.n_tasks @n_tasks.setter def n_tasks(self, value: int) -> None: - cast(_NashMTLWeighting, self.weighting).n_tasks = value + self.weighting.n_tasks = value @property def max_norm(self) -> float: - return cast(_NashMTLWeighting, self.weighting).max_norm + return self.weighting.max_norm @max_norm.setter def max_norm(self, value: float) -> None: - cast(_NashMTLWeighting, self.weighting).max_norm = value + self.weighting.max_norm = value @property def update_weights_every(self) -> int: - return cast(_NashMTLWeighting, self.weighting).update_weights_every + return self.weighting.update_weights_every @update_weights_every.setter def update_weights_every(self, value: int) -> None: - cast(_NashMTLWeighting, self.weighting).update_weights_every = value + self.weighting.update_weights_every = value @property def optim_niter(self) -> int: - return cast(_NashMTLWeighting, self.weighting).optim_niter + return self.weighting.optim_niter @optim_niter.setter def optim_niter(self, value: int) -> None: - cast(_NashMTLWeighting, self.weighting).optim_niter = value + self.weighting.optim_niter = value def reset(self) -> None: """Resets the internal state of the algorithm.""" - cast(_NashMTLWeighting, self.weighting).reset() + self.weighting.reset() def __repr__(self) -> str: return ( From 3f7c499a451e743524bc9a34aead3bfb7bf3b100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 21 Apr 2026 13:43:44 +0200 Subject: [PATCH 12/15] Improve changelog entry --- CHANGELOG.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74be3971a..fe52e1fec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,16 +10,15 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and - `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is - non-negative. -- Added getters and setters for the constructor parameters of `AlignedMTL` / `AlignedMTLWeighting` - (`pref_vector`, `scale_mode`), `CAGrad` / `CAGradWeighting` (`c`, `norm_eps`), `ConFIG` - (`pref_vector`), `DualProj` / `DualProjWeighting` (`pref_vector`, `norm_eps`, `reg_eps`), - `GradDrop` (`leak`), `Krum` / `KrumWeighting` (`n_byzantine`, `n_selected`), `MGDA` / - `MGDAWeighting` (`epsilon`, `max_iters`), `NashMTL` / `_NashMTLWeighting` (`n_tasks`, `max_norm`, - `update_weights_every`, `optim_niter`), and `TrimmedMean` (`trim_number`). Setters validate their - inputs matching the existing constructor checks (e.g. non-negative or strictly-positive). 
+- Added getters and setters for the constructor parameters of all aggregators and weightings, so
+  that they can be changed after initialization. This includes: `pref_vector`,
+  `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
+  `pref_vector` and `scale_mode` in `AlignedMTL` and `AlignedMTLWeighting`; `c` and `norm_eps` in
+  `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`; `n_byzantine` and
+  `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and
+  `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
+  `trim_number` in `TrimmedMean`. Setters validate their inputs, matching the existing constructor
+  checks. Note that setters for `GradVac` and `GradVacWeighting` already existed.
 
 ## [0.10.0] - 2026-04-16
 

From 5fbe93eadf368ba84d716851585d55ecb2248a3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Tue, 21 Apr 2026 13:48:20 +0200
Subject: [PATCH 13/15] Make GradVac use its setters at init

---
 src/torchjd/aggregation/_gradvac.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index 5a1edc355..26a075b90 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -41,13 +41,8 @@ class GradVacWeighting(GramianWeighting, Stateful):
 
     def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
         super().__init__()
-        if not (0.0 <= beta <= 1.0):
-            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
-        if eps < 0.0:
-            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")
-
-        self._beta = beta
-        self._eps = eps
+        self.beta = beta
+        self.eps = eps
 
         self._phi_t: Tensor | None = None
         self._state_key: tuple[int, torch.dtype] | None = None

From ce64ca3ea1c215b30548d670ddf85faef0e5e84c54 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com>
Date: Tue, 21 Apr 2026 13:52:25 +0200
Subject: [PATCH 14/15] Also test that aggregator's setters modify the weighting's params
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
---
 tests/unit/aggregation/test_aligned_mtl.py | 1 +
 tests/unit/aggregation/test_cagrad.py      | 2 ++
 tests/unit/aggregation/test_dualproj.py    | 2 ++
 tests/unit/aggregation/test_krum.py        | 2 ++
 tests/unit/aggregation/test_mgda.py        | 2 ++
 tests/unit/aggregation/test_nash_mtl.py    | 4 ++++
 6 files changed, 13 insertions(+)

diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py
index 303064637..6eacfba9b 100644
--- a/tests/unit/aggregation/test_aligned_mtl.py
+++ b/tests/unit/aggregation/test_aligned_mtl.py
@@ -58,3 +58,4 @@ def test_scale_mode_setter_updates_value() -> None:
     A = AlignedMTL()
     A.scale_mode = "rmse"
     assert A.scale_mode == "rmse"
+    assert A.gramian_weighting.scale_mode == "rmse"
diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py
index ee479746f..9128899fd 100644
--- a/tests/unit/aggregation/test_cagrad.py
+++ b/tests/unit/aggregation/test_cagrad.py
@@ -64,12 +64,14 @@ def test_c_setter_updates_value() -> None:
     A = CAGrad(c=0.5)
     A.c = 1.25
     assert A.c == 1.25
+    assert A.gramian_weighting.c == 1.25
 
 
 def test_norm_eps_setter_updates_value() -> None:
     A = CAGrad(c=0.5)
     A.norm_eps = 0.25
     assert A.norm_eps 
== 0.25
+    assert A.gramian_weighting.norm_eps == 0.25
 
 
 def test_c_setter_rejects_negative() -> None:
diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py
index dbe5d4c10..34fe8d462 100644
--- a/tests/unit/aggregation/test_dualproj.py
+++ b/tests/unit/aggregation/test_dualproj.py
@@ -79,12 +79,14 @@ def test_norm_eps_setter_updates_value() -> None:
     A = DualProj()
     A.norm_eps = 0.25
     assert A.norm_eps == 0.25
+    assert A.gramian_weighting.norm_eps == 0.25
 
 
 def test_reg_eps_setter_updates_value() -> None:
     A = DualProj()
     A.reg_eps = 0.25
     assert A.reg_eps == 0.25
+    assert A.gramian_weighting.reg_eps == 0.25
 
 
 def test_norm_eps_setter_rejects_negative() -> None:
diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py
index 5a86601bf..6270f6778 100644
--- a/tests/unit/aggregation/test_krum.py
+++ b/tests/unit/aggregation/test_krum.py
@@ -85,12 +85,14 @@ def test_n_byzantine_setter_updates_value() -> None:
     A = Krum(n_byzantine=1)
     A.n_byzantine = 3
     assert A.n_byzantine == 3
+    assert A.gramian_weighting.n_byzantine == 3
 
 
 def test_n_selected_setter_updates_value() -> None:
     A = Krum(n_byzantine=1)
     A.n_selected = 3
     assert A.n_selected == 3
+    assert A.gramian_weighting.n_selected == 3
 
 
 def test_n_byzantine_setter_rejects_negative() -> None:
diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py
index b9d6873f9..f6c67236b 100644
--- a/tests/unit/aggregation/test_mgda.py
+++ b/tests/unit/aggregation/test_mgda.py
@@ -76,12 +76,14 @@ def test_epsilon_setter_updates_value() -> None:
     A = MGDA()
     A.epsilon = 0.25
     assert A.epsilon == 0.25
+    assert A.gramian_weighting.epsilon == 0.25
 
 
 def test_max_iters_setter_updates_value() -> None:
     A = MGDA()
     A.max_iters = 42
     assert A.max_iters == 42
+    assert A.gramian_weighting.max_iters == 42
 
 
 def test_epsilon_setter_rejects_non_positive() -> None:
diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py
index 7d2cb0f9e..6dac1f0ea 100644
--- a/tests/unit/aggregation/test_nash_mtl.py
+++ b/tests/unit/aggregation/test_nash_mtl.py
@@ -85,6 +85,10 @@ def test_setters_update_values() -> None:
     assert A.max_norm == 2.5
     assert A.update_weights_every == 3
     assert A.optim_niter == 7
+    assert A.weighting.n_tasks == 4
+    assert A.weighting.max_norm == 2.5
+    assert A.weighting.update_weights_every == 3
+    assert A.weighting.optim_niter == 7
 
 
 def test_n_tasks_setter_rejects_non_positive() -> None:

From 6ab28710f8e3156f8831e31e646851d5c0385ce2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Tue, 21 Apr 2026 13:56:28 +0200
Subject: [PATCH 15/15] Update GradVac tests to uniformize them a bit with the rest

---
 tests/unit/aggregation/test_gradvac.py | 96 +++++++++++---------------
 1 file changed, 42 insertions(+), 54 deletions(-)

diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py
index bde2e8fd6..28db60905 100644
--- a/tests/unit/aggregation/test_gradvac.py
+++ b/tests/unit/aggregation/test_gradvac.py
@@ -20,48 +20,6 @@ def test_representations() -> None:
     assert str(A) == "GradVac"
 
 
-def test_beta_out_of_range() -> None:
-    with raises(ValueError, match="beta"):
-        GradVac(beta=-0.1)
-    with raises(ValueError, match="beta"):
-        GradVac(beta=1.1)
-
-
-def test_beta_setter_out_of_range() -> None:
-    A = GradVac()
-    with raises(ValueError, match="beta"):
-        A.beta = -0.1
-    with raises(ValueError, match="beta"):
-        A.beta = 1.1
-
-
-def test_beta_setter_updates_value() -> None:
-    A = GradVac()
-    A.beta = 0.25
-    assert 
A.beta == 0.25 - - -def test_eps_rejects_negative() -> None: - with raises(ValueError, match="eps"): - GradVac(eps=-1e-9) - - -def test_eps_setter_rejects_negative() -> None: - A = GradVac() - with raises(ValueError, match="eps"): - A.eps = -1e-9 - - -def test_eps_can_be_changed_between_steps() -> None: - J = tensor_([[1.0, 0.0], [0.0, 1.0]]) - A = GradVac() - A.eps = 1e-6 - assert A(J).isfinite().all() - A.reset() - A.eps = 1e-10 - assert A(J).isfinite().all() - - def test_zero_rows_returns_zero_vector() -> None: out = GradVac()(tensor_([]).reshape(0, 3)) assert_close(out, tensor_([0.0, 0.0, 0.0])) @@ -104,18 +62,6 @@ def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_weighting_beta_out_of_range() -> None: - with raises(ValueError, match="beta"): - GradVacWeighting(beta=-0.1) - with raises(ValueError, match="beta"): - GradVacWeighting(beta=1.1) - - -def test_weighting_eps_rejects_negative() -> None: - with raises(ValueError, match="eps"): - GradVacWeighting(eps=-1e-9) - - def test_weighting_reset_restores_first_step_behavior() -> None: J = randn_((3, 8)) G = J @ J.T @@ -144,3 +90,45 @@ def test_aggregator_and_weighting_agree() -> None: result = weights @ J assert_close(result, expected, rtol=1e-4, atol=1e-4) + + +def test_beta_setter_updates_value() -> None: + A = GradVac() + A.beta = 0.25 + assert A.beta == 0.25 + assert A.gramian_weighting.beta == 0.25 + + +def test_eps_setter_updates_value() -> None: + A = GradVac() + A.eps = 1e-6 + assert A.eps == 1e-6 + assert A.gramian_weighting.eps == 1e-6 + + +def test_beta_setter_rejects_out_of_range() -> None: + A = GradVac() + with raises(ValueError, match="beta"): + A.beta = -0.1 + with raises(ValueError, match="beta"): + A.beta = 1.1 + + +def test_eps_setter_rejects_negative() -> None: + A = GradVac() + with raises(ValueError, match="eps"): + A.eps = -1e-9 + + +def test_weighting_beta_setter_rejects_out_of_range() -> None: + W = GradVacWeighting() + with raises(ValueError, match="beta"): + W.beta = -0.1 + with raises(ValueError, match="beta"): + W.beta = 1.1 + + +def test_weighting_eps_setter_rejects_negative() -> None: + W = GradVacWeighting() + with raises(ValueError, match="eps"): + W.eps = -1e-9