74 changes: 38 additions & 36 deletions src/torchjd/aggregation/_aligned_mtl.py
@@ -17,42 +17,6 @@
SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]


class AlignedMTL(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Independent Component Alignment for Multi-Task Learning
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).

.. note::
This implementation was adapted from the official implementation of SamsungLabs/MTL,
which is no longer available at the time of writing.
"""

def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> None:
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
f"scale_mode={repr(self._scale_mode)})"
)

def __str__(self) -> str:
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"


class AlignedMTLWeighting(Weighting[PSDMatrix]):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -113,3 +77,41 @@ def _compute_balance_transformation(

B = scale.sqrt() * V @ sigma_inv @ V.T
return B
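
The collapsed hunk above hides most of `_compute_balance_transformation`, but its visible tail suggests the balance transformation is assembled from an eigendecomposition of the Gramian. A minimal sketch of that construction, assuming `sigma_inv` holds the inverse square roots of the eigenvalues and `scale` is the eigenvalue selected by `scale_mode` (both are assumptions, not the file's actual code):

import torch

def balance_transformation_sketch(gramian: torch.Tensor, scale_mode: str = "min") -> torch.Tensor:
    # Eigendecomposition of the symmetric PSD Gramian: gramian = V diag(sigma) V^T.
    sigma, V = torch.linalg.eigh(gramian)
    sigma = sigma.clamp(min=0.0)  # guard against tiny negative eigenvalues
    # Assumption: sigma_inv is the diagonal matrix of inverse square roots of the eigenvalues.
    sigma_inv = torch.diag(torch.where(sigma > 0, sigma.rsqrt(), torch.zeros_like(sigma)))
    # Assumption: scale is the eigenvalue picked according to scale_mode.
    positive = sigma[sigma > 0]
    if scale_mode == "min":
        scale = positive.min()
    elif scale_mode == "median":
        scale = positive.median()
    else:  # "rmse" uses the mean eigenvalue, as in the original implementation
        scale = sigma.mean()
    B = scale.sqrt() * V @ sigma_inv @ V.T
    return B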


class AlignedMTL(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Independent Component Alignment for Multi-Task Learning
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).

.. note::
This implementation was adapted from the official implementation of SamsungLabs/MTL,
which is no longer available at the time of writing.
"""

gramian_weighting: AlignedMTLWeighting

def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> None:
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
f"scale_mode={repr(self._scale_mode)})"
)

def __str__(self) -> str:
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
66 changes: 34 additions & 32 deletions src/torchjd/aggregation/_cagrad.py
@@ -18,38 +18,6 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGrad(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
<https://arxiv.org/pdf/2110.14048.pdf>`_.

:param c: The scale of the radius of the ball constraint.
:param norm_eps: A small value to avoid division by zero when normalizing.

.. note::
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
To install it, use ``pip install "torchjd[cagrad]"``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self._c = c
self._norm_eps = norm_eps

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

def __str__(self) -> str:
c_str = str(self._c).rstrip("0")
return f"CAGrad{c_str}"


class CAGradWeighting(Weighting[PSDMatrix]):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -104,3 +72,37 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights


class CAGrad(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
<https://arxiv.org/pdf/2110.14048.pdf>`_.

:param c: The scale of the radius of the ball constraint.
:param norm_eps: A small value to avoid division by zero when normalizing.

.. note::
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
To install it, use ``pip install "torchjd[cagrad]"``.
"""

gramian_weighting: CAGradWeighting

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self._c = c
self._norm_eps = norm_eps

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

def __str__(self) -> str:
c_str = str(self._c).rstrip("0")
return f"CAGrad{c_str}"
42 changes: 22 additions & 20 deletions src/torchjd/aggregation/_constant.py
@@ -7,26 +7,6 @@
from ._weighting_bases import Weighting


class Constant(WeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
the rows of the provided matrix, with constant, pre-determined weights.

:param weights: The weights associated with the rows of the input matrices.
"""

def __init__(self, weights: Tensor) -> None:
super().__init__(weighting=ConstantWeighting(weights=weights))
self._weights = weights

def __repr__(self) -> str:
return f"{self.__class__.__name__}(weights={repr(self._weights)})"

def __str__(self) -> str:
weights_str = vector_to_str(self._weights)
return f"{self.__class__.__name__}([{weights_str}])"


class ConstantWeighting(Weighting[Matrix]):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
@@ -55,3 +35,25 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
)


class Constant(WeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
the rows of the provided matrix, with constant, pre-determined weights.

:param weights: The weights associated with the rows of the input matrices.
"""

weighting: ConstantWeighting

def __init__(self, weights: Tensor) -> None:
super().__init__(weighting=ConstantWeighting(weights=weights))
self._weights = weights

def __repr__(self) -> str:
return f"{self.__class__.__name__}(weights={repr(self._weights)})"

def __str__(self) -> str:
weights_str = vector_to_str(self._weights)
return f"{self.__class__.__name__}([{weights_str}])"
74 changes: 38 additions & 36 deletions src/torchjd/aggregation/_dualproj.py
@@ -10,6 +10,42 @@
from ._weighting_bases import Weighting


class DualProjWeighting(Weighting[PSDMatrix]):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.DualProj`.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not be exactly positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver for the underlying optimization problem.
"""

def __init__(
self,
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix, /) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
return w


class DualProj(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
@@ -27,6 +63,8 @@ class DualProj(GramianWeightedAggregator):
:param solver: The solver for the underlying optimization problem.
"""

gramian_weighting: DualProjWeighting

def __init__(
self,
pref_vector: Tensor | None = None,
@@ -54,39 +92,3 @@ def __repr__(self) -> str:

def __str__(self) -> str:
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"


class DualProjWeighting(Weighting[PSDMatrix]):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.DualProj`.

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not be exactly positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver for the underlying optimization problem.
"""

def __init__(
self,
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix, /) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
return w
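
Finally, a usage sketch for DualProj with its defaults (uniform preference vector, quadprog solver); illustrative values, same calling convention as above:

import torch
from torchjd.aggregation import DualProj

J = torch.tensor([[-4.0, 1.0, 1.0],
                  [ 6.0, 1.0, 1.0]])
aggregator = DualProj()
g = aggregator(J)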