12 changes: 9 additions & 3 deletions CHANGELOG.md
@@ -10,9 +10,15 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
`UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
non-negative.
- Added getters and setters for the constructor parameters of all aggregators and weightings, so
  that they can be changed after initialization (see the usage sketch below). This includes
  `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and
  `DualProjWeighting`; `pref_vector` and `scale_mode` in `AlignedMTL` and `AlignedMTLWeighting`;
  `c` and `norm_eps` in `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in
  `GradDrop`; `n_byzantine` and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and
  `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and
  `optim_niter` in `NashMTL`; and `trim_number` in `TrimmedMean`. The setters validate their
  inputs with the same checks as the constructors. Note that setters for `GradVac` and
  `GradVacWeighting` already existed.

## [0.10.0] - 2026-04-16

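The new setters make it possible to reconfigure an aggregator in place rather than rebuilding it. A minimal usage sketch, assuming `CAGrad` is re-exported from `torchjd.aggregation` (the class itself lives in the private `_cagrad` module changed below):

```python
# Minimal usage sketch; assumes the public import path `torchjd.aggregation`.
from torchjd.aggregation import CAGrad

aggregator = CAGrad(c=0.5)
aggregator.c = 0.8          # reconfigure after initialization
aggregator.norm_eps = 1e-3  # setters run the same checks as the constructor

try:
    aggregator.c = -1.0     # rejected by the non-negativity check
except ValueError as err:
    print(err)              # c must be non-negative, but got -1.0.
```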
40 changes: 31 additions & 9 deletions src/torchjd/aggregation/_aligned_mtl.py
@@ -35,17 +35,25 @@ def __init__(
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.pref_vector = pref_vector
self.scale_mode: SUPPORTED_SCALE_MODE = scale_mode

def forward(self, gramian: PSDMatrix, /) -> Tensor:
w = self.weighting(gramian)
B = self._compute_balance_transformation(gramian, self._scale_mode)
B = self._compute_balance_transformation(gramian, self.scale_mode)
alpha = B @ w

return alpha

@property
def pref_vector(self) -> Tensor | None:
return self._pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
self._pref_vector = value

@staticmethod
def _compute_balance_transformation(
M: Tensor,
@@ -103,15 +111,29 @@ def __init__(
pref_vector: Tensor | None = None,
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> None:
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.gramian_weighting.pref_vector = value

@property
def scale_mode(self) -> SUPPORTED_SCALE_MODE:
return self.gramian_weighting.scale_mode

@scale_mode.setter
def scale_mode(self, value: SUPPORTED_SCALE_MODE) -> None:
self.gramian_weighting.scale_mode = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
f"scale_mode={repr(self._scale_mode)})"
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, "
f"scale_mode={repr(self.scale_mode)})"
)

def __str__(self) -> str:
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
return f"AlignedMTL{pref_vector_to_str_suffix(self.pref_vector)}"
48 changes: 40 additions & 8 deletions src/torchjd/aggregation/_cagrad.py
@@ -37,10 +37,6 @@ class CAGradWeighting(GramianWeighting):

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__()

if c < 0.0:
raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")

self.c = c
self.norm_eps = norm_eps

@@ -73,6 +69,28 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

return weights

@property
def c(self) -> float:
return self._c

@c.setter
def c(self, value: float) -> None:
if value < 0:
raise ValueError(f"c must be non-negative, but got {value}.")

self._c = value

@property
def norm_eps(self) -> float:
return self._norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
if value < 0:
raise ValueError(f"norm_eps must be non-negative, but got {value}.")

self._norm_eps = value


class CAGrad(GramianWeightedAggregator):
"""
@@ -94,15 +112,29 @@ class CAGrad(GramianWeightedAggregator):

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self._c = c
self._norm_eps = norm_eps

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def c(self) -> float:
return self.gramian_weighting.c

@c.setter
def c(self, value: float) -> None:
self.gramian_weighting.c = value

@property
def norm_eps(self) -> float:
return self.gramian_weighting.norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
self.gramian_weighting.norm_eps = value

def __repr__(self) -> str:
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
return f"{self.__class__.__name__}(c={self.c}, norm_eps={self.norm_eps})"

def __str__(self) -> str:
c_str = str(self._c).rstrip("0")
c_str = str(self.c).rstrip("0")
return f"CAGrad{c_str}"
16 changes: 12 additions & 4 deletions src/torchjd/aggregation/_config.py
@@ -29,8 +29,7 @@ class ConFIG(Aggregator):

def __init__(self, pref_vector: Tensor | None = None) -> None:
super().__init__()
self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
self._pref_vector = pref_vector
self.pref_vector = pref_vector

# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)
@@ -46,8 +45,17 @@ def forward(self, matrix: Matrix, /) -> Tensor:

return length * unit_target_vector

@property
def pref_vector(self) -> Tensor | None:
return self._pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.weighting = pref_vector_to_weighting(value, default=SumWeighting())
self._pref_vector = value

def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
return f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)})"

def __str__(self) -> str:
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
return f"ConFIG{pref_vector_to_str_suffix(self.pref_vector)}"
67 changes: 59 additions & 8 deletions src/torchjd/aggregation/_dualproj.py
@@ -33,8 +33,7 @@ def __init__(
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.pref_vector = pref_vector
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver
@@ -45,6 +44,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
w = project_weights(u, G, self.solver)
return w

@property
def pref_vector(self) -> Tensor | None:
return self._pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
self._pref_vector = value

@property
def norm_eps(self) -> float:
return self._norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
if value < 0:
raise ValueError(f"norm_eps must be non-negative, but got {value}.")

self._norm_eps = value

@property
def reg_eps(self) -> float:
return self._reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
if value < 0:
raise ValueError(f"reg_eps must be non-negative, but got {value}.")

self._reg_eps = value


class DualProj(GramianWeightedAggregator):
r"""
@@ -72,9 +102,6 @@ def __init__(
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
@@ -84,11 +111,35 @@ def __init__(
# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.gramian_weighting.pref_vector = value

@property
def norm_eps(self) -> float:
return self.gramian_weighting.norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
self.gramian_weighting.norm_eps = value

@property
def reg_eps(self) -> float:
return self.gramian_weighting.reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
self.gramian_weighting.reg_eps = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
)

def __str__(self) -> str:
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
return f"DualProj{pref_vector_to_str_suffix(self.pref_vector)}"
19 changes: 13 additions & 6 deletions src/torchjd/aggregation/_graddrop.py
@@ -27,12 +27,6 @@ class GradDrop(Aggregator):
"""

def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
if leak is not None and leak.dim() != 1:
raise ValueError(
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
f"{leak.shape}`.",
)

super().__init__()
self.f = f
self.leak = leak
@@ -59,6 +53,19 @@ def forward(self, matrix: Matrix, /) -> Tensor:

return vector

@property
def leak(self) -> Tensor | None:
return self._leak

@leak.setter
def leak(self, value: Tensor | None) -> None:
if value is not None and value.dim() != 1:
raise ValueError(
f"leak must be a 1-dimensional tensor. Found leak.shape = {value.shape}.",
)

self._leak = value

def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
n_rows = matrix.shape[0]
if self.leak is not None and n_rows != len(self.leak):
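The `leak` setter keeps the 1-dimensionality check that previously lived in the constructor, so invalid shapes are now rejected on reassignment as well. A minimal sketch, assuming `GradDrop` is re-exported from `torchjd.aggregation`:

```python
# Minimal sketch; assumes the public import path `torchjd.aggregation`.
import torch

from torchjd.aggregation import GradDrop

aggregator = GradDrop()
aggregator.leak = torch.ones(3)         # fine: 1-dimensional

try:
    aggregator.leak = torch.ones(2, 3)  # rejected: must be 1-dimensional
except ValueError as err:
    print(err)
```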
9 changes: 2 additions & 7 deletions src/torchjd/aggregation/_gradvac.py
@@ -41,13 +41,8 @@ class GradVacWeighting(GramianWeighting, Stateful):

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
super().__init__()
if not (0.0 <= beta <= 1.0):
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
if eps < 0.0:
raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

self._beta = beta
self._eps = eps
self.beta = beta
self.eps = eps
self._phi_t: Tensor | None = None
self._state_key: tuple[int, torch.dtype] | None = None
