diff --git a/CHANGELOG.md b/CHANGELOG.md index d2aa0eef..46aa6333 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and + `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is + non-negative. + ## [0.10.0] - 2026-04-16 ### Added diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index b09c0a59..68689829 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -34,8 +34,7 @@ def __init__( solver: SUPPORTED_SOLVER = "quadprog", ) -> None: super().__init__() - self._pref_vector = pref_vector - self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) + self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps self.solver: SUPPORTED_SOLVER = solver @@ -46,6 +45,39 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: W = project_weights(U, G, self.solver) return torch.sum(W, dim=0) + @property + def pref_vector(self) -> Tensor | None: + return self._pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.weighting = pref_vector_to_weighting(value, default=MeanWeighting()) + self._pref_vector = value + + @property + def norm_eps(self) -> float: + return self._norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + + if value < 0: + raise ValueError(f"norm_eps must be non-negative, but got {value}.") + + self._norm_eps = value + + @property + def reg_eps(self) -> float: + return self._reg_eps + + @reg_eps.setter + def reg_eps(self, value: float) -> None: + + if value < 0: + raise ValueError(f"reg_eps must be non-negative, but got {value}.") + + self._reg_eps = value + class UPGrad(GramianWeightedAggregator): r""" @@ -73,9 +105,6 @@ def __init__( reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", ) -> None: - self._pref_vector = pref_vector - self._norm_eps = norm_eps - self._reg_eps = reg_eps self._solver: SUPPORTED_SOLVER = solver super().__init__( @@ -85,11 +114,35 @@ def __init__( # This prevents considering the computed weights as constant w.r.t. the matrix. self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def pref_vector(self) -> Tensor | None: + return self.gramian_weighting.pref_vector + + @pref_vector.setter + def pref_vector(self, value: Tensor | None) -> None: + self.gramian_weighting.pref_vector = value + + @property + def norm_eps(self) -> float: + return self.gramian_weighting.norm_eps + + @norm_eps.setter + def norm_eps(self, value: float) -> None: + self.gramian_weighting.norm_eps = value + + @property + def reg_eps(self) -> float: + return self.gramian_weighting.reg_eps + + @reg_eps.setter + def reg_eps(self, value: float) -> None: + self.gramian_weighting.reg_eps = value + def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" ) def __str__(self) -> str: - return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}" + return f"UPGrad{pref_vector_to_str_suffix(self.pref_vector)}" diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 1859b662..075680a0 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -1,9 +1,10 @@ import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor from utils.tensors import ones_ -from torchjd.aggregation import UPGrad +from torchjd.aggregation import ConstantWeighting, UPGrad +from torchjd.aggregation._upgrad import UPGradWeighting from ._asserts import ( assert_expected_structure, @@ -67,3 +68,48 @@ def test_representations() -> None: "solver='quadprog')" ) assert str(A) == "UPGrad([1., 2., 3.])" + + +def test_pref_vector_setter_updates_value() -> None: + A = UPGrad() + new_pref = torch.tensor([1.0, 2.0, 3.0]) + A.pref_vector = new_pref + assert A.pref_vector is new_pref + assert isinstance(A.gramian_weighting.weighting, ConstantWeighting) + assert A.gramian_weighting.weighting.weights is new_pref + + +def test_norm_eps_setter_updates_value() -> None: + A = UPGrad() + A.norm_eps = 0.25 + assert A.norm_eps == 0.25 + + +def test_reg_eps_setter_updates_value() -> None: + A = UPGrad() + A.reg_eps = 0.25 + assert A.reg_eps == 0.25 + + +def test_norm_eps_setter_rejects_negative() -> None: + A = UPGrad() + with raises(ValueError, match="norm_eps"): + A.norm_eps = -1e-9 + + +def test_reg_eps_setter_rejects_negative() -> None: + A = UPGrad() + with raises(ValueError, match="reg_eps"): + A.reg_eps = -1e-9 + + +def test_weighting_norm_eps_setter_rejects_negative() -> None: + W = UPGradWeighting() + with raises(ValueError, match="norm_eps"): + W.norm_eps = -1e-9 + + +def test_weighting_reg_eps_setter_rejects_negative() -> None: + W = UPGradWeighting() + with raises(ValueError, match="reg_eps"): + W.reg_eps = -1e-9