diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index efdf95f1d5a2ab94b777667fcd3787fb27fdea29..9bf89c63e678d6d58a3bbded6d06765e9251ae22 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -23,6 +23,7 @@ gradient descent.
 
 This module exposes:
 * TorchModel: Model subclass to wrap torch.nn.Module objects
+* TorchOptiModule: OptiModule subclass to wrap torch.optim.Optimizer objects
 * TorchVector: Vector subclass to wrap torch.Tensor objects
 
 It also exposes the `utils` submodule, which mainly aims at
@@ -31,4 +32,5 @@ providing tools used in the backend of the former objects.
 
 from . import utils
 from ._vector import TorchVector
+from ._optim import TorchOptiModule
 from ._model import TorchModel
diff --git a/declearn/model/torch/_optim.py b/declearn/model/torch/_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb505bbd5cac9b2ed67e2257ac07bb2a67292c7
--- /dev/null
+++ b/declearn/model/torch/_optim.py
@@ -0,0 +1,324 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Hacky OptiModule subclass enabling the use of a torch.optim.Optimizer."""
+
+import importlib
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from typing_extensions import Self  # future: import from typing (py >=3.11)
+
+from declearn.model.api import Vector
+from declearn.model.torch.utils import select_device
+from declearn.model.torch._vector import TorchVector
+from declearn.optimizer.modules import OptiModule
+from declearn.utils import get_device_policy
+
+
+__all__ = [
+    "TorchOptiModule",
+]
+
+
+class TorchOptiModule(OptiModule):
+    """Hacky OptiModule subclass to wrap a torch Optimizer.
+
+    This torch-only OptiModule enables wrapping a `torch.optim.Optimizer`
+    to make it part of a declearn Optimizer pipeline, where it may be
+    combined with other framework-agnostic tools (notably FL-specific
+    ones such as the FedProx loss regularizer).
+
+    The wrapped torch Optimizer's states will be placed on a device (CPU
+    or GPU) that is selected automatically, based either on the first
+    input gradients' placement, or on the global device policy when
+    `set_state` is used. The `reset` method may be used to drop both the
+    internal optimizer states and the device-placement choice at once.
+
+    Please note that this module relies on a hack that may have unforeseen
+    side effects on the optimization algorithm if used carelessly, and that
+    will at any rate cause some memory overhead. It should therefore be
+    used sparingly, taking the following constraints and limitations into
+    account:
+
+    * The wrapped optimizer class should have a "lr" (learning-rate)
+      parameter, which will be forced to 1.0, so that scaling the updates
+      remains the responsibility of the wrapping declearn Optimizer.
+    * The wrapped optimizer class should not make use of the watched
+      parameters' values, only of their gradients, because it will in
+      fact monitor artificial, zero-valued parameters at each step.
+    * If the module is to be used by the clients, the wrapped optimizer
+      class must have been imported from a third-party package that is
+      also available to the clients (e.g. torch).
+
+    Also note that passing a string as `optim_cls` (as is always done when
+    deserializing the module from its auto-generated config) may raise
+    security concerns, as it results in importing external code. As a
+    consequence, users will be asked to validate any non-torch import
+    before it is executed. This prompt may be disabled when instantiating
+    the module through its init constructor, but not when using
+    `from_config`, `from_specs` or `deserialize`.
+
+    This class is mostly provided for experimental use of algorithms that
+    are not natively available in declearn, for users that do not want to
+    put in (or want to reserve for later) the effort of writing a custom,
+    dedicated, framework-agnostic OptiModule subclass implementing that
+    algorithm. If you encounter issues, please report to the declearn
+    developers, who will be happy to assist with debugging the present
+    module and/or implementing the desired algorithm as a proper
+    OptiModule.
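+
+    Usage
+    -----
+    A minimal usage sketch, plugging `torch.optim.Adam` into a declearn
+    Optimizer (the learning-rate value below is a mere illustration; the
+    wrapped torch optimizer itself always runs with lr=1.0 internally):
+
+    >>> from declearn.optimizer import Optimizer
+    >>> optim = Optimizer(
+    ...     lrate=0.001,
+    ...     modules=[TorchOptiModule(torch.optim.Adam)],
+    ... )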
+    """
+
+    name = "torch-optim"
+
+    def __init__(
+        self,
+        optim_cls: Union[Type[torch.optim.Optimizer], str],
+        validate: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """Instantiate a hacky torch optimizer plug-in module.
+
+        Parameters
+        ----------
+        optim_cls: type[torch.optim.Optimizer] or str
+            Class constructor of the torch Optimizer that needs wrapping.
+            A string containing its import path may be provided instead.
+        validate: bool, default=True
+            Whether the user should be prompted to validate the module-
+            import action triggered when `optim_cls` is a string that
+            targets another package than 'torch'.
+        **kwargs: Any
+            Keyword arguments to `optim_cls`.
+            Note that "lr" will be forced to 1.0.
+        """
+        self.optim_cls = self._validate_optim_cls(optim_cls, validate)
+        self.kwargs = kwargs
+        self.kwargs["lr"] = 1.0
+        self._params = {}  # type: Dict[str, torch.nn.Parameter]
+        self._optim = None  # type: Optional[torch.optim.Optimizer]
+
+    def _validate_optim_cls(
+        self,
+        optim_cls: Union[Type[torch.optim.Optimizer], str],
+        validate: bool = True,
+    ) -> Type[torch.optim.Optimizer]:
+        """Type-check and optionally import a torch Optimizer class.
+
+        Parameters
+        ----------
+        optim_cls: type[torch.optim.Optimizer] or str
+            Either a torch Optimizer class constructor, or the import
+            path to one, from which it will be retrieved.
+        validate: bool, default=True
+            Whether the user should be prompted to validate the module-
+            import action triggered when `optim_cls` is a string that
+            targets another package than 'torch'.
+
+        Raises
+        ------
+        RuntimeError:
+            If `optim_cls` is a string and the target class cannot be
+            loaded, or if the user denies the import command.
+        TypeError:
+            If `optim_cls` (or the object loaded in case it is a string)
+            is not a `torch.optim.Optimizer` subclass.
+
+        Returns
+        -------
+        optim_cls: type[torch.optim.Optimizer]
+            Torch Optimizer class constructor.
+        """
+        if isinstance(optim_cls, str):
+            try:
+                module, name = optim_cls.rsplit(".", 1)
+                if validate and (module.split(".", 1)[0] != "torch"):
+                    accept = input(
+                        f"TorchOptiModule requires importing the '{module}' "
+                        "module.\nDo you agree to this? [y/N] "
+                    )
+                    if not accept.lower().startswith("y"):
+                        raise RuntimeError(
+                            f"User refused to import {module}."
+                        )
+                optim_mod = importlib.import_module(module)
+                optim_cls = getattr(optim_mod, name)
+            except (AttributeError, ModuleNotFoundError, RuntimeError) as exc:
+                raise RuntimeError(
+                    "Could not load TorchOptiModule's wrapped "
+                    f"torch optimizer class: {exc}"
+                ) from exc
+        if not (
+            isinstance(optim_cls, type)
+            and issubclass(optim_cls, torch.optim.Optimizer)
+        ):
+            raise TypeError(
+                "'optim_cls' should be a torch.optim.Optimizer subclass."
+            )
+        return optim_cls
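+
+    # How the `run` hack operates:
+    # 1. maintain artificial, zero-valued torch Parameters that mirror
+    #    the input gradients' specification;
+    # 2. assign the *negated* input gradients as these parameters' grads;
+    # 3. call the wrapped optimizer's `step()` with lr=1.0: under plain
+    #    SGD each parameter becomes 0 - 1.0 * (-grad) = grad, and the
+    #    same sign-flipping argument extends to optimizers whose update
+    #    rules are odd functions of the gradients (Adam, RMSprop, etc.),
+    #    which is why the wrapped class must only rely on gradients;
+    # 4. return the parameters' new values as the transformed gradients.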
[y/N] " + ) + if not accept.lower().startswith("y"): + raise RuntimeError(f"User refused to import {module}.") + optim_mod = importlib.import_module(module) + optim_cls = getattr(optim_mod, name) + except (AttributeError, ModuleNotFoundError, RuntimeError) as exc: + raise RuntimeError( + "Could not load TorchOptiModule's wrapped " + f"torch optimizer class: {exc}" + ) from exc + if not ( + isinstance(optim_cls, type) + and issubclass(optim_cls, torch.optim.Optimizer) + ): + raise TypeError( + "'optim_cls' should be a torch Optimizer subclass." + ) + return optim_cls + + def run( + self, + gradients: Vector, + ) -> Vector: + """Run input gradients through the wrapped torch Optimizer. + + Parameters + ---------- + gradients: TorchVector + Input gradients that are to be processed and updated. + + Raises + ------ + TypeError: + If `gradients` are not a TorchVector (this module is + a framework-specific hack). + KeyError: + If `gradients` have an inconsistent spec with the first + ones ever processed by this module. Use `reset` if you + wish to start back from the beginning. + + Returns + ------- + gradients: TorchVector + Modified input gradients. The output Vector should be + fully compatible with the input one - only the values + of the wrapped coefficients may have changed. + """ + if not isinstance(gradients, TorchVector): + raise TypeError( + "TorchOptiModule only supports TorchVector input gradients." + ) + if self._optim is None: + self._optim = self._init_optimizer(gradients) + if gradients.coefs.keys() != self._params.keys(): + raise KeyError( + "Mismatch between input gradients and stored parameters." + ) + for key, grad in gradients.coefs.items(): + param = self._params[key] + with torch.no_grad(): + param.zero_() + param.grad = -grad.to(param.device) # devices *must* be the same + self._optim.step() + coefs = { + key: param.detach().clone().to(gradients.coefs[key].device) + for key, param in self._params.items() + } + return TorchVector(coefs) + + def _init_optimizer(self, gradients: TorchVector) -> torch.optim.Optimizer: + """Instantiate and return a torch Optimizer to make use of. + + Place the artifical parameters and optimizer states on the + same device as the input gradients. + """ + # false-positive on torch.zeros_like; pylint: disable=no-member + self._params = { + key: torch.nn.Parameter(torch.zeros_like(grad)) + for key, grad in gradients.coefs.items() + } + return self.optim_cls(list(self._params.values()), **self.kwargs) + + def reset(self) -> None: + """Reset this module to its uninitialized state. + + Discard the wrapped torch parameters (that define a required + specification of input gradients) and torch Optimizer. As a + consequence, the next call to `run` will result in creating + a new Optimizer from scratch and setting a new specification. + """ + self._params = {} + self._optim = None + + def get_config( + self, + ) -> Dict[str, Any]: + optim_cls = f"{self.optim_cls.__module__}.{self.optim_cls.__name__}" + return {"optim_cls": optim_cls, "kwargs": self.kwargs} + + @classmethod + def from_config( + cls, + config: Dict[str, Any], + ) -> Self: + if "optim_cls" not in config: + raise TypeError( + "TorchOptiModule config is missing required key 'optim_cls'." 
+ ) + kwargs = config.get("kwargs", {}) + kwargs.pop("validate", None) # force manual validation of imports + return cls(config["optim_cls"], validate=True, **kwargs) + + def get_state( + self, + ) -> Dict[str, Any]: + params = TorchVector({k: p.data for k, p in self._params.items()}) + dtypes = params.dtypes() + shapes = params.shapes() + specs = {key: (shapes[key], dtypes[key]) for key in self._params} + sdict = ( + {"state": {}} if self._optim is None else self._optim.state_dict() + ) + state = [] # type: List[Tuple[int, Dict[str, Any]]] + for key, group in sdict["state"].items(): + gval = { + k: v.cpu().numpy().copy() if isinstance(v, torch.Tensor) else v + for k, v in group.items() + } + state.append((key, gval)) + return {"specs": specs, "state": state} + + def set_state( + self, + state: Dict[str, Any], + ) -> None: + for key in ("specs", "state"): + if key not in state: + raise KeyError( + "Missing required key in input TorchOptiModule state " + f"dict: '{key}'." + ) + self.reset() + # Early-exit if reloading from an uninitialized state. + if not state["state"]: + return None + # Consult the global device policy to place the variables and states. + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + # Restore weight variables' specifications from the input state dict. + self._params = {} + for key, (shape, dtype) in state["specs"].items(): + zeros = torch.zeros( # false-positive; pylint: disable=no-member + tuple(shape), dtype=getattr(torch, dtype), device=device + ) + self._params[key] = torch.nn.Parameter(zeros) + self._optim = self.optim_cls( + list(self._params.values()), **self.kwargs + ) + # Restore optimizer variables' values from the input state dict. + sdict = self._optim.state_dict() + sdict["state"] = { + key: { + k: ( + torch.from_numpy(v).to(device) # pylint: disable=no-member + if isinstance(v, np.ndarray) + else v + ) + for k, v in group.items() + } + for key, group in state["state"] + } + self._optim.load_state_dict(sdict) + return None diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py index 1b2fdf53dfff3dc53a8e23b9f868394cb06059a5..dcd9c904a6c67cb8ac27c26005780cd8444f1efe 100644 --- a/test/optimizer/test_modules.py +++ b/test/optimizer/test_modules.py @@ -54,8 +54,9 @@ from optim_testing import PluginTestBase sys.path.pop() # fmt: on - +# Access the list of modules to test; remove some that have dedicated tests. OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule") +OPTIMODULE_SUBCLASSES.pop("torch-optim", None) set_device_policy(gpu=False) # run all OptiModule tests on CPU diff --git a/test/optimizer/test_torch_optim.py b/test/optimizer/test_torch_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..e032a4c25e29f915139cc998f8c9a678cd8d7867 --- /dev/null +++ b/test/optimizer/test_torch_optim.py @@ -0,0 +1,245 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the TorchOptiModule class.""" + +import importlib +import sys +from unittest import mock +from typing import Iterator, Type + +import numpy as np +import pytest + +try: + import torch +except ModuleNotFoundError: + pytest.skip("Torch is unavailable", allow_module_level=True) + +from declearn.model.torch import TorchOptiModule, TorchVector +from declearn.optimizer.modules import OptiModule +from declearn.test_utils import GradientsTestCase +from declearn.utils import set_device_policy + + +# dirty trick to import from `model_testing.py`; +# fmt: off +# pylint: disable=wrong-import-order, wrong-import-position +sys.path.append(".") +from test_modules import OptiModuleTestSuite +sys.path.pop() +# fmt: on + + +set_device_policy(gpu=False) # force most tests to run on CPU + + +DEVICES = ["CPU"] +if torch.cuda.device_count(): + DEVICES.append("GPU") + + +@pytest.fixture(name="optim_cls") +def optim_cls_fixture(optim: str) -> str: + """Fixture to provide with a torch Optimizer's import path.""" + return { + "adam": "torch.optim.Adam", + "rmsprop": "torch.optim.RMSprop", + "adagrad": "torch.optim.Adagrad", + }[optim] + + +@pytest.fixture(name="cls") +def cls_fixture(optim_cls: str) -> Iterator[Type[TorchOptiModule]]: + """Fixture to provide with preset TorchOptiModule constructors.""" + defaults = TorchOptiModule.__init__.__defaults__ + TorchOptiModule.__init__.__defaults__ = (optim_cls, False) + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + yield TorchOptiModule + TorchOptiModule.__init__.__defaults__ = defaults + + +def get_optim_dec(optim_cls: str) -> OptiModule: + """Instanciate an OptiModule parameterized to match a torch one.""" + optim = optim_cls.rsplit(".", 1)[-1].lower() + if optim == "adam": + name = "adam" + kwargs = {"beta_1": 0.9, "beta_2": 0.999, "eps": 1e-08} + elif optim == "rmsprop": + name = "rmsprop" + kwargs = {"beta": 0.99, "eps": 1e-08} + elif optim == "adagrad": + name = "adagrad" + kwargs = {"eps": 1e-10} + else: + raise KeyError(f"Invalid 'optim' fixture parameter value: '{optim}'.") + return OptiModule.from_specs(name, kwargs) + + +@pytest.fixture(name="framework") +def framework_fixture(): + """Fixture to ensure 'TorchOptiModule' only receives torch gradients.""" + return "torch" + + +@pytest.mark.parametrize("optim", ["adam", "rmsprop", "adagrad"]) +class TestTorchOptiModule(OptiModuleTestSuite): + """Unit tests for declearn.model.torch.TorchOptiModule.""" + + def test_run_equivalence(self, cls: Type[OptiModule]) -> None: + # This test is undefined for a framework-specific plugin. + pass + + def test_validate_torch(self, optim_cls: str) -> None: + """Test that user-validation of torch modules' import is skipped.""" + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + TorchOptiModule(optim_cls, validate=True) + patch_input.assert_not_called() + + def test_equivalent_declearn(self, optim_cls: str) -> None: + """Test that declearn modules are equivalent to torch ones. + + Instantiate a TorchOptiModule wrapping a torch.optim.Optimizer, + as well as a declearn OptiModule that matches its configuration + (adjusting hyper-parameters to torch's defaults). + + Ensure that on 10 successive passes on the same random-valued + input gradients, outputs have the same values, up to numerical + precision (relative tolerance of 10^-5, absolute of 10^-8). 
+ """ + optim_pyt = TorchOptiModule(optim_cls, validate=False) + optim_dec = get_optim_dec(optim_cls) + gradients = GradientsTestCase("torch").mock_gradient + for _ in range(10): + grads_pyt = optim_pyt.run(gradients).coefs + grads_dec = optim_dec.run(gradients).coefs + assert all( + np.allclose(grads_pyt[key].numpy(), grads_dec[key].numpy()) + for key in gradients.coefs + ) + + def test_reset(self, optim_cls: str) -> None: + """Test that the `TorchOptiModule.reset` method works.""" + # Set up a module and two sets of differently-labeled gradients. + module = TorchOptiModule(optim_cls, validate=False) + grads_a = GradientsTestCase("torch").mock_gradient + grads_b = TorchVector( + {f"{key}_bis": val for key, val in grads_a.coefs.items()} + ) + # Check that running inconsistent gradients fails. + outpt_a = module.run(grads_a) + assert isinstance(outpt_a, TorchVector) + with pytest.raises(KeyError): + module.run(grads_b) + # Test that resetting enables setting up a new input spec. + module.reset() + outpt_b = module.run(grads_b) + assert isinstance(outpt_b, TorchVector) + # Check that results are indeed the same, save for their names. + # This means inner states have been properly reset. + outpt_a = TorchVector( + {f"{key}_bis": val for key, val in outpt_a.coefs.items()} + ) + assert outpt_b == outpt_a + + @pytest.mark.parametrize("device", DEVICES) + def test_device_placement(self, optim_cls: str, device: str) -> None: + """Test that the optimizer and computations are properly placed.""" + # Set the device policy, setup a module and run computations. + set_device_policy(gpu=(device == "GPU"), idx=None) + module = TorchOptiModule(optim_cls) + grads = GradientsTestCase("torch").mock_gradient + if device == "GPU": + for key, tensor in grads.coefs.items(): + grads.coefs[key] = tensor.cuda() + updts = module.run(grads) + # Assert that the outputs and internal states are properly placed. + dtype = "cuda" if device == "GPU" else "cpu" + assert all(t.device.type == dtype for t in updts.coefs.values()) + state = module.get_state()["state"] + assert all( + tensor.device.type == dtype + for _, group in state + for tensor in group + if isinstance(tensor, torch.Tensor) + ) + # Reset device policy to run other tests on CPU as expected. + set_device_policy(gpu=False) + + +class FakeOptimizer(torch.optim.Optimizer): + """Fake torch Optimizer subclass.""" + + step = mock.create_autospec(torch.optim.Optimizer.step) + + +class EmptyClass: + """Empty class to test non-torch-optimizer imports' failure.""" + + # mock class; pylint: disable=too-few-public-methods + + __init__ = mock.MagicMock() + + +class TestTorchOptiModuleValidateImports: + """Test that user-validation of third-party modules' import works.""" + + def test_validate_accept(self) -> None: + """Assert that (fake) user inputs enable validating imports.""" + # Set up the import string for FakeOptimizer and pre-import its module. + string = f"{__name__}.FakeOptimizer" + module = importlib.import_module(__name__) + # Run the TorchOptiModule instantiation with patched objects. + with mock.patch("builtins.input") as patch_input: + with mock.patch("importlib.import_module") as patch_import: + patch_input.return_value = "y" + patch_import.return_value = module + optim = TorchOptiModule(string, validate=True) + # Assert the expected calls were made and FakeOptimizer was assigned. 
+        patch_input.assert_called_once()
+        patch_import.assert_called_once_with(__name__)
+        assert optim.optim_cls is FakeOptimizer
+
+    def test_validate_deny(self) -> None:
+        """Assert that (fake) user inputs enable blocking imports."""
+        # Run the TorchOptiModule instantiation with a fake denial input.
+        with mock.patch("builtins.input") as patch_input:
+            with mock.patch("importlib.import_module") as patch_import:
+                patch_input.return_value = "n"
+                with pytest.raises(RuntimeError):
+                    TorchOptiModule(f"{__name__}.FakeOptimizer", validate=True)
+        # Assert the expected calls were made (or not).
+        patch_input.assert_called_once()
+        patch_import.assert_not_called()
+
+    def test_validate_wrong_class(self) -> None:
+        """Assert that importing an invalid class raises a TypeError."""
+        # Set up the import string for EmptyClass; pre-import its module.
+        string = f"{__name__}.EmptyClass"
+        module = importlib.import_module(__name__)
+        # Run the TorchOptiModule instantiation with patched objects.
+        with mock.patch("builtins.input") as patch_input:
+            with mock.patch("importlib.import_module") as patch_import:
+                patch_input.return_value = "y"
+                patch_import.return_value = module
+                with pytest.raises(TypeError):
+                    TorchOptiModule(string, validate=True)
+        # Assert the expected calls were made.
+        patch_input.assert_called_once()
+        patch_import.assert_called_once_with(__name__)