Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
`UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
non-negative.

## [0.10.0] - 2026-04-16

### Added
Expand Down
69 changes: 61 additions & 8 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def __init__(
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.pref_vector = pref_vector
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver
Expand All @@ -46,6 +45,39 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
W = project_weights(U, G, self.solver)
return torch.sum(W, dim=0)

@property
def pref_vector(self) -> Tensor | None:
return self._pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
Comment thread
ValerianRey marked this conversation as resolved.
self._pref_vector = value

@property
def norm_eps(self) -> float:
return self._norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:

if value < 0:
raise ValueError(f"norm_eps must be non-negative, but got {value}.")

self._norm_eps = value

@property
def reg_eps(self) -> float:
return self._reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:

if value < 0:
raise ValueError(f"reg_eps must be non-negative, but got {value}.")

self._reg_eps = value


class UPGrad(GramianWeightedAggregator):
r"""
Expand Down Expand Up @@ -73,9 +105,6 @@ def __init__(
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
Expand All @@ -85,11 +114,35 @@ def __init__(
# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.gramian_weighting.pref_vector = value

@property
def norm_eps(self) -> float:
return self.gramian_weighting.norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
self.gramian_weighting.norm_eps = value

@property
def reg_eps(self) -> float:
return self.gramian_weighting.reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
self.gramian_weighting.reg_eps = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
)

def __str__(self) -> str:
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"
return f"UPGrad{pref_vector_to_str_suffix(self.pref_vector)}"
50 changes: 48 additions & 2 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from pytest import mark
from pytest import mark, raises
from torch import Tensor
from utils.tensors import ones_

from torchjd.aggregation import UPGrad
from torchjd.aggregation import ConstantWeighting, UPGrad
from torchjd.aggregation._upgrad import UPGradWeighting

from ._asserts import (
assert_expected_structure,
Expand Down Expand Up @@ -67,3 +68,48 @@ def test_representations() -> None:
"solver='quadprog')"
)
assert str(A) == "UPGrad([1., 2., 3.])"


def test_pref_vector_setter_updates_value() -> None:
A = UPGrad()
new_pref = torch.tensor([1.0, 2.0, 3.0])
A.pref_vector = new_pref
assert A.pref_vector is new_pref
Comment thread
ValerianRey marked this conversation as resolved.
assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
assert A.gramian_weighting.weighting.weights is new_pref


def test_norm_eps_setter_updates_value() -> None:
A = UPGrad()
A.norm_eps = 0.25
assert A.norm_eps == 0.25


def test_reg_eps_setter_updates_value() -> None:
A = UPGrad()
A.reg_eps = 0.25
assert A.reg_eps == 0.25


def test_norm_eps_setter_rejects_negative() -> None:
A = UPGrad()
with raises(ValueError, match="norm_eps"):
A.norm_eps = -1e-9


def test_reg_eps_setter_rejects_negative() -> None:
A = UPGrad()
with raises(ValueError, match="reg_eps"):
A.reg_eps = -1e-9


def test_weighting_norm_eps_setter_rejects_negative() -> None:
W = UPGradWeighting()
with raises(ValueError, match="norm_eps"):
W.norm_eps = -1e-9


def test_weighting_reg_eps_setter_rejects_negative() -> None:
W = UPGradWeighting()
with raises(ValueError, match="reg_eps"):
W.reg_eps = -1e-9
Loading