diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py
index 29b8cd0b3..9db3b8a77 100644
--- a/src/torchjd/_linalg/__init__.py
+++ b/src/torchjd/_linalg/__init__.py
@@ -1,8 +1,11 @@
 from ._generalized_gramian import flatten, movedim, reshape
 from ._gramian import compute_gramian, normalize, regularize
 from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
+from ._structure import Structure, extract_structure
 
 __all__ = [
+    "extract_structure",
+    "Structure",
     "compute_gramian",
     "normalize",
     "regularize",
diff --git a/src/torchjd/_linalg/_structure.py b/src/torchjd/_linalg/_structure.py
new file mode 100644
index 000000000..303a6c112
--- /dev/null
+++ b/src/torchjd/_linalg/_structure.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+
+import torch
+
+from ._matrix import Matrix
+
+
+@dataclass
+class Structure:
+    m: int
+    device: torch.device
+    dtype: torch.dtype
+
+
+def extract_structure(matrix: Matrix) -> Structure:
+    return Structure(m=matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py
index 0485e7261..bf63cf8c8 100644
--- a/src/torchjd/aggregation/_constant.py
+++ b/src/torchjd/aggregation/_constant.py
@@ -1,20 +1,13 @@
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd.aggregation._weighting_bases import FromNothingWeighting
 
 from ._aggregator_bases import WeightedAggregator
 from ._utils.str import vector_to_str
 from ._weighting_bases import Weighting
 
 
-class ConstantWeighting(Weighting[Matrix]):
-    """
-    :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
-    weights.
-
-    :param weights: The weights to return at each call.
-    """
-
+class _ConstantWeighting(Weighting[None]):
     def __init__(self, weights: Tensor) -> None:
         if weights.dim() != 1:
             raise ValueError(
@@ -25,16 +18,20 @@ def __init__(self, weights: Tensor) -> None:
         super().__init__()
         self.weights = weights
 
-    def forward(self, matrix: Tensor, /) -> Tensor:
-        self._check_matrix_shape(matrix)
+    def forward(self, _: None, /) -> Tensor:
         return self.weights
 
-    def _check_matrix_shape(self, matrix: Tensor) -> None:
-        if matrix.shape[0] != len(self.weights):
-            raise ValueError(
-                f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
-                f"weights). Found `matrix` with {matrix.shape[0]} rows.",
-            )
+
+class ConstantWeighting(FromNothingWeighting):
+    """
+    :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
+    weights.
+
+    :param weights: The weights to return at each call.
+    """
+
+    def __init__(self, weights: Tensor) -> None:
+        super().__init__(_ConstantWeighting(weights))
 
 
 class Constant(WeightedAggregator):
diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py
index 2ebe208de..a4bde82f0 100644
--- a/src/torchjd/aggregation/_mean.py
+++ b/src/torchjd/aggregation/_mean.py
@@ -1,25 +1,31 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd._linalg import Structure
+from torchjd.aggregation._weighting_bases import FromStructureWeighting
 
 from ._aggregator_bases import WeightedAggregator
 from ._weighting_bases import Weighting
 
 
-class MeanWeighting(Weighting[Matrix]):
+class _MeanWeighting(Weighting[Structure]):
+    def forward(self, structure: Structure, /) -> Tensor:
+        device = structure.device
+        dtype = structure.dtype
+        m = structure.m
+        weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
+        return weights
+
+
+class MeanWeighting(FromStructureWeighting):
     r"""
     :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
     :math:`\begin{bmatrix} \frac{1}{m} & \dots &
     \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
     """
 
-    def forward(self, matrix: Tensor, /) -> Tensor:
-        device = matrix.device
-        dtype = matrix.dtype
-        m = matrix.shape[0]
-        weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
-        return weights
+    def __init__(self) -> None:
+        super().__init__(_MeanWeighting())
 
 
 class Mean(WeightedAggregator):
diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py
index 8345a15cb..b61cc0fb5 100644
--- a/src/torchjd/aggregation/_random.py
+++ b/src/torchjd/aggregation/_random.py
@@ -2,22 +2,28 @@
 from torch import Tensor
 from torch.nn import functional as F
 
-from torchjd._linalg import Matrix
+from torchjd._linalg import Structure
+from torchjd.aggregation._weighting_bases import FromStructureWeighting
 
 from ._aggregator_bases import WeightedAggregator
 from ._weighting_bases import Weighting
 
 
-class RandomWeighting(Weighting[Matrix]):
+class _RandomWeighting(Weighting[Structure]):
+    def forward(self, structure: Structure, /) -> Tensor:
+        random_vector = torch.randn(structure.m, device=structure.device, dtype=structure.dtype)
+        weights = F.softmax(random_vector, dim=-1)
+        return weights
+
+
+class RandomWeighting(FromStructureWeighting):
     """
     :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
    at each call.
     """
 
-    def forward(self, matrix: Tensor, /) -> Tensor:
-        random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
-        weights = F.softmax(random_vector, dim=-1)
-        return weights
+    def __init__(self) -> None:
+        super().__init__(_RandomWeighting())
 
 
 class Random(WeightedAggregator):
diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py
index 0754f4668..b46888326 100644
--- a/src/torchjd/aggregation/_sum.py
+++ b/src/torchjd/aggregation/_sum.py
@@ -1,23 +1,27 @@
 import torch
 from torch import Tensor
 
-from torchjd._linalg import Matrix
+from torchjd._linalg import Structure
+from torchjd.aggregation._weighting_bases import FromStructureWeighting
 
 from ._aggregator_bases import WeightedAggregator
 from ._weighting_bases import Weighting
 
 
-class SumWeighting(Weighting[Matrix]):
+class _SumWeighting(Weighting[Structure]):
+    def forward(self, structure: Structure, /) -> Tensor:
+        weights = torch.ones(structure.m, device=structure.device, dtype=structure.dtype)
+        return weights
+
+
+class SumWeighting(FromStructureWeighting):
     r"""
     :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
     :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
     """
 
-    def forward(self, matrix: Tensor, /) -> Tensor:
-        device = matrix.device
-        dtype = matrix.dtype
-        weights = torch.ones(matrix.shape[0], device=device, dtype=dtype)
-        return weights
+    def __init__(self) -> None:
+        super().__init__(_SumWeighting())
 
 
 class Sum(WeightedAggregator):
diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py
index e321169c3..ee91f2347 100644
--- a/src/torchjd/aggregation/_weighting_bases.py
+++ b/src/torchjd/aggregation/_weighting_bases.py
@@ -6,11 +6,11 @@
 
 from torch import Tensor, nn
 
-from torchjd._linalg import PSDTensor, is_psd_tensor
+from torchjd._linalg import Matrix, PSDTensor, Structure, extract_structure, is_psd_tensor
 
-_T = TypeVar("_T", contravariant=True, bound=Tensor)
-_FnInputT = TypeVar("_FnInputT", bound=Tensor)
-_FnOutputT = TypeVar("_FnOutputT", bound=Tensor)
+_T = TypeVar("_T", contravariant=True)
+_FnInputT = TypeVar("_FnInputT")
+_FnOutputT = TypeVar("_FnOutputT")
 
 
 class Weighting(nn.Module, ABC, Generic[_T]):
@@ -27,11 +27,9 @@ def __init__(self) -> None:
 
     def forward(self, stat: _T, /) -> Tensor:
         """Computes the vector of weights from the input stat."""
 
-    def __call__(self, stat: Tensor, /) -> Tensor:
+    def __call__(self, stat: object, /) -> Tensor:
         """Computes the vector of weights from the input stat and applies all registered hooks."""
-        # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
-        # stat to be Tensor.
         return super().__call__(stat)
 
     def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]:
@@ -55,6 +53,32 @@ def forward(self, stat: _T, /) -> Tensor:
         return self.weighting(self.fn(stat))
 
 
+class FromStructureWeighting(_Composition[Matrix]):
+    """
+    Weighting that extracts the structure of the input matrix before applying a Weighting to it.
+
+    :param structure_weighting: The object responsible for extracting the vector of weights from the
+        structure.
+    """
+
+    def __init__(self, structure_weighting: Weighting[Structure]) -> None:
+        super().__init__(structure_weighting, extract_structure)
+        self.structure_weighting = structure_weighting
+
+
+class FromNothingWeighting(_Composition[Matrix]):
+    """
+    Weighting that extracts nothing from the input matrix before applying a Weighting to it (i.e. to
+    None).
+
+    :param none_weighting: The object responsible for extracting the vector of weights from nothing.
+    """
+
+    def __init__(self, none_weighting: Weighting[None]) -> None:
+        super().__init__(none_weighting, lambda _: None)
+        self.none_weighting = none_weighting
+
+
 class GeneralizedWeighting(nn.Module, ABC):
     r"""
     Abstract base class for all weightings that operate on generalized Gramians. It has the role of
diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py
index aa1332fcb..4fa4488ef 100644
--- a/tests/unit/aggregation/test_constant.py
+++ b/tests/unit/aggregation/test_constant.py
@@ -63,29 +63,6 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext) -> None:
     _ = Constant(weights=weights)
 
 
-@mark.parametrize(
-    ["weights_shape", "n_rows", "expectation"],
-    [
-        ([0], 0, does_not_raise()),
-        ([1], 1, does_not_raise()),
-        ([5], 5, does_not_raise()),
-        ([0], 1, raises(ValueError)),
-        ([1], 0, raises(ValueError)),
-        ([4], 5, raises(ValueError)),
-        ([5], 4, raises(ValueError)),
-    ],
-)
-def test_matrix_shape_check(
-    weights_shape: list[int], n_rows: int, expectation: ExceptionContext
-) -> None:
-    matrix = ones_([n_rows, 5])
-    weights = ones_(weights_shape)
-    aggregator = Constant(weights)
-
-    with expectation:
-        _ = aggregator(matrix)
-
-
 def test_representations() -> None:
     A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
     assert repr(A) == "Constant(weights=tensor([1., 2.]))"