diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index efdf95f1d5a2ab94b777667fcd3787fb27fdea29..9bf89c63e678d6d58a3bbded6d06765e9251ae22 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -23,6 +23,7 @@ gradient descent.
 
 This module exposes:
 * TorchModel: Model subclass to wrap torch.nn.Module objects
+* TorchOptiModule: OptiModule subclass to wrap torch.optim.Optimizer objects
 * TorchVector: Vector subclass to wrap torch.Tensor objects
 
 It also exposes the `utils` submodule, which mainly aims at
@@ -31,4 +32,5 @@ providing tools used in the backend of the former objects.
 
 from . import utils
 from ._vector import TorchVector
+from ._optim import TorchOptiModule
 from ._model import TorchModel
diff --git a/declearn/model/torch/_optim.py b/declearn/model/torch/_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb505bbd5cac9b2ed67e2257ac07bb2a67292c7
--- /dev/null
+++ b/declearn/model/torch/_optim.py
@@ -0,0 +1,338 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Hacky OptiModule subclass enabling the use of a torch.nn.Optimizer."""
+
+import importlib
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from typing_extensions import Self  # future: import from typing (py >=3.11)
+
+from declearn.model.api import Vector
+from declearn.model.torch.utils import select_device
+from declearn.model.torch._vector import TorchVector
+from declearn.optimizer.modules import OptiModule
+from declearn.utils import get_device_policy
+
+
+__all__ = [
+    "TorchOptiModule",
+]
+
+
+class TorchOptiModule(OptiModule):
+    """Hacky OptiModule subclass to wrap a torch Optimizer.
+
+    This torch-only OptiModule enables wrapping a `torch.optim.Optimizer`
+    to make it part of a declearn Optimizer pipeline, where it may be
+    combined with other framework-agnostic tools (notably FL-specific
+    ones such as the FedProx loss regularizer).
+
+    The wrapped torch Optimizer states will be placed on a device (CPU
+    or GPU) selected automatically based on the first input gradients'
+    placement OR on the global device policy when `set_state` is used.
+    The `reset` method may be used to drop internal optimizer states
+    and device-placement choices at once.
+
+    Please note that this class relies on a hack that may have unforeseen
+    side effects on the optimization algorithm if used carelessly, and
+    that will in any case incur some memory overhead. It should thus be
+    used sparingly, with the following constraints and limitations in mind:
+
+    * The wrapped optimizer class should have a "lr" (learning-rate)
+      parameter, which will be forced to 1.0, so that updates' scaling
+      remains the responsibility of the wrapping declearn Optimizer.
+    * The wrapped optimizer class should not make use of the watched
+      parameters' values, only of their gradients, because it will in
+      fact monitor artificial, zero-valued parameters at each step.
+    * If the module is to be used by the clients, the wrapped optimizer
+      class must have been imported from a third-party package that is
+      also available to the clients (e.g. torch).
+
+    Also note that passing a string as `optim_cls` (as is always done when
+    deserializing the module from its auto-generated config) raises security
+    concerns, since it results in importing third-party code. As a safeguard,
+    users will be prompted to validate any non-torch import before it is
+    executed. This prompt may be disabled when instantiating the module via
+    its init constructor, but not when using `from_config`, `from_specs` or
+    `deserialize`.
+
+    This class is mostly provided for experimental use of algorithms that
+    are not natively available in declearn, for users that do not want to
+    put in (or reserve for later) the effort of writing a custom, dedicated,
+    framework-agnostic OptiModule subclass implementing that algorithm.
+    If you encounter issues, please report to the declearn developers, and
+    we will be happy to assist with debugging the present module and/or
+    implementing the desired algorithm as a proper OptiModule.
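+
+    Usage sketch (a minimal example; "torch.optim.Adam" is a standard
+    torch import path, and `Optimizer` is declearn's generic optimizer):
+
+    >>> from declearn.optimizer import Optimizer
+    >>> module = TorchOptiModule("torch.optim.Adam", validate=False)
+    >>> optim = Optimizer(lrate=0.001, modules=[module])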
+    """
+
+    name = "torch-optim"
+
+    def __init__(
+        self,
+        optim_cls: Union[Type[torch.optim.Optimizer], str],
+        validate: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """Instantiate a hacky torch optimizer plug-in module.
+
+        Parameters
+        ----------
+        optim_cls: type[torch.optim.Optimizer] or str
+            Class constructor of the torch Optimizer that needs wrapping.
+            A string containing its import path may be provided instead.
+        validate: bool, default=True
+            Whether the user should be prompted to validate the module-
+            import action triggered in case `optim_cls` is a string and
+            targets another package than 'torch'.
+        **kwargs: Any
+            Keyword arguments to `optim_cls`.
+            Note that "lr" will be forced to 1.0.
+        """
+        self.optim_cls = self._validate_optim_cls(optim_cls, validate)
+        self.kwargs = kwargs
+        self.kwargs["lr"] = 1.0
+        self._params = {}  # type: Dict[str, torch.nn.Parameter]
+        self._optim = None  # type: Optional[torch.optim.Optimizer]
+
+    def _validate_optim_cls(
+        self,
+        optim_cls: Union[Type[torch.optim.Optimizer], str],
+        validate: bool = True,
+    ) -> Type[torch.optim.Optimizer]:
+        """Type-check and optionally import a torch Optimizer class.
+
+        Parameters
+        ----------
+        optim_cls: Type[torch.optim.Optimizer] or str
+            Either a torch Optimizer class constructor, or the import path
+            to one, from which it will be retrieved.
+        validate: bool, default=True
+            Whether the user should be prompted to validate the module-
+            import action triggered in case `optim_cls` is a string and
+            targets another package than 'torch'.
+
+        Raises
+        ------
+        RuntimeError:
+            If `optim_cls` is a string and the target class cannot be loaded.
+            If `optim_cls` is a string and the user denies the import command.
+        TypeError:
+            If `optim_cls` (or the object loaded in case it is a string)
+            is not a `torch.optim.Optimizer` subclass.
+
+        Returns
+        -------
+        optim_cls: Type[torch.optim.Optimizer]
+            Torch Optimizer class constructor.
+        """
+        if isinstance(optim_cls, str):
+            try:
+                module, name = optim_cls.rsplit(".", 1)
+                if validate and (module.split(".", 1)[0] != "torch"):
+                    accept = input(
+                        f"TorchOptiModule requires importing the '{module}' "
+                        "module.\nDo you agree to this? [y/N] "
+                    )
+                    if not accept.lower().startswith("y"):
+                        raise RuntimeError(f"User refused to import {module}.")
+                optim_mod = importlib.import_module(module)
+                optim_cls = getattr(optim_mod, name)
+            except (
+                AttributeError, ModuleNotFoundError, RuntimeError, ValueError
+            ) as exc:
+                raise RuntimeError(
+                    "Could not load TorchOptiModule's wrapped "
+                    f"torch optimizer class: {exc}"
+                ) from exc
+        if not (
+            isinstance(optim_cls, type)
+            and issubclass(optim_cls, torch.optim.Optimizer)
+        ):
+            raise TypeError(
+                "'optim_cls' should be a torch Optimizer subclass."
+            )
+        return optim_cls
+
+    def run(
+        self,
+        gradients: Vector,
+    ) -> Vector:
+        """Run input gradients through the wrapped torch Optimizer.
+
+        Parameters
+        ----------
+        gradients: TorchVector
+            Input gradients that are to be processed and updated.
+
+        Raises
+        ------
+        TypeError:
+            If `gradients` is not a TorchVector (as this module is
+            a framework-specific hack).
+        KeyError:
+            If `gradients` has a spec inconsistent with that of the
+            first gradients processed by this module. Use `reset` if
+            you wish to start over with a new specification.
+
+        Returns
+        -------
+        gradients: TorchVector
+            Modified input gradients. The output Vector is fully
+            compatible with the input one; only the values of the
+            wrapped coefficients may have changed.
+        """
+        if not isinstance(gradients, TorchVector):
+            raise TypeError(
+                "TorchOptiModule only supports TorchVector input gradients."
+            )
+        if self._optim is None:
+            self._optim = self._init_optimizer(gradients)
+        if gradients.coefs.keys() != self._params.keys():
+            raise KeyError(
+                "Mismatch between input gradients and stored parameters."
+            )
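+        # Zero out the artificial parameters, then plug in the negated
+        # gradients, so that the lr=1 descent step below leaves the
+        # parameters holding the optimizer-transformed gradients.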
+        for key, grad in gradients.coefs.items():
+            param = self._params[key]
+            with torch.no_grad():
+                param.zero_()
+            param.grad = -grad.to(param.device)  # devices *must* be the same
+        self._optim.step()
+        coefs = {
+            key: param.detach().clone().to(gradients.coefs[key].device)
+            for key, param in self._params.items()
+        }
+        return TorchVector(coefs)
+
+    def _init_optimizer(self, gradients: TorchVector) -> torch.optim.Optimizer:
+        """Instantiate and return a torch Optimizer to make use of.
+
+        Place the artificial parameters and optimizer states on the
+        same device as the input gradients.
+        """
+        # false-positive on torch.zeros_like; pylint: disable=no-member
+        self._params = {
+            key: torch.nn.Parameter(torch.zeros_like(grad))
+            for key, grad in gradients.coefs.items()
+        }
+        return self.optim_cls(list(self._params.values()), **self.kwargs)
+
+    def reset(self) -> None:
+        """Reset this module to its uninitialized state.
+
+        Discard the wrapped torch parameters (that define a required
+        specification of input gradients) and torch Optimizer. As a
+        consequence, the next call to `run` will result in creating
+        a new Optimizer from scratch and setting a new specification.
+        """
+        self._params = {}
+        self._optim = None
+
+    def get_config(
+        self,
+    ) -> Dict[str, Any]:
+        optim_cls = f"{self.optim_cls.__module__}.{self.optim_cls.__name__}"
+        return {"optim_cls": optim_cls, "kwargs": self.kwargs}
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Dict[str, Any],
+    ) -> Self:
+        if "optim_cls" not in config:
+            raise TypeError(
+                "TorchOptiModule config is missing required key 'optim_cls'."
+            )
+        kwargs = config.get("kwargs", {})
+        kwargs.pop("validate", None)  # force manual validation of imports
+        return cls(config["optim_cls"], validate=True, **kwargs)
+
+    def get_state(
+        self,
+    ) -> Dict[str, Any]:
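+        # Gather the wrapped parameters' specs and the optimizer's inner
+        # state, converting tensors to numpy arrays for portability.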
+        params = TorchVector({k: p.data for k, p in self._params.items()})
+        dtypes = params.dtypes()
+        shapes = params.shapes()
+        specs = {key: (shapes[key], dtypes[key]) for key in self._params}
+        sdict = (
+            {"state": {}} if self._optim is None else self._optim.state_dict()
+        )
+        state = []  # type: List[Tuple[int, Dict[str, Any]]]
+        for key, group in sdict["state"].items():
+            gval = {
+                k: v.cpu().numpy().copy() if isinstance(v, torch.Tensor) else v
+                for k, v in group.items()
+            }
+            state.append((key, gval))
+        return {"specs": specs, "state": state}
+
+    def set_state(
+        self,
+        state: Dict[str, Any],
+    ) -> None:
+        for key in ("specs", "state"):
+            if key not in state:
+                raise KeyError(
+                    "Missing required key in input TorchOptiModule state "
+                    f"dict: '{key}'."
+                )
+        self.reset()
+        # Early-exit if reloading from an uninitialized state.
+        if not state["state"]:
+            return None
+        # Consult the global device policy to place the variables and states.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Restore weight variables' specifications from the input state dict.
+        self._params = {}
+        for key, (shape, dtype) in state["specs"].items():
+            zeros = torch.zeros(  # false-positive; pylint: disable=no-member
+                tuple(shape), dtype=getattr(torch, dtype), device=device
+            )
+            self._params[key] = torch.nn.Parameter(zeros)
+        self._optim = self.optim_cls(
+            list(self._params.values()), **self.kwargs
+        )
+        # Restore optimizer variables' values from the input state dict.
+        sdict = self._optim.state_dict()
+        sdict["state"] = {
+            key: {
+                k: (
+                    torch.from_numpy(v).to(device)  # pylint: disable=no-member
+                    if isinstance(v, np.ndarray)
+                    else v
+                )
+                for k, v in group.items()
+            }
+            for key, group in state["state"]
+        }
+        self._optim.load_state_dict(sdict)
+        return None
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 1b2fdf53dfff3dc53a8e23b9f868394cb06059a5..dcd9c904a6c67cb8ac27c26005780cd8444f1efe 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -54,8 +54,9 @@ from optim_testing import PluginTestBase
 sys.path.pop()
 # fmt: on
 
-
+# Access the mapping of modules to test; drop those with dedicated tests.
 OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
+OPTIMODULE_SUBCLASSES.pop("torch-optim", None)
 
 set_device_policy(gpu=False)  # run all OptiModule tests on CPU
 
diff --git a/test/optimizer/test_torch_optim.py b/test/optimizer/test_torch_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e032a4c25e29f915139cc998f8c9a678cd8d7867
--- /dev/null
+++ b/test/optimizer/test_torch_optim.py
@@ -0,0 +1,247 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the TorchOptiModule class."""
+
+import importlib
+import sys
+from unittest import mock
+from typing import Iterator, Type
+
+import numpy as np
+import pytest
+
+try:
+    import torch
+except ModuleNotFoundError:
+    pytest.skip("Torch is unavailable", allow_module_level=True)
+
+from declearn.model.torch import TorchOptiModule, TorchVector
+from declearn.optimizer.modules import OptiModule
+from declearn.test_utils import GradientsTestCase
+from declearn.utils import set_device_policy
+
+
+# dirty trick to import from `test_modules.py`;
+# fmt: off
+# pylint: disable=wrong-import-order, wrong-import-position
+sys.path.append(".")
+from test_modules import OptiModuleTestSuite
+sys.path.pop()
+# fmt: on
+
+
+set_device_policy(gpu=False)  # force most tests to run on CPU
+
+
+DEVICES = ["CPU"]
+if torch.cuda.device_count():
+    DEVICES.append("GPU")
+
+
+@pytest.fixture(name="optim_cls")
+def optim_cls_fixture(optim: str) -> str:
+    """Fixture to provide with a torch Optimizer's import path."""
+    return {
+        "adam": "torch.optim.Adam",
+        "rmsprop": "torch.optim.RMSprop",
+        "adagrad": "torch.optim.Adagrad",
+    }[optim]
+
+
+@pytest.fixture(name="cls")
+def cls_fixture(optim_cls: str) -> Iterator[Type[TorchOptiModule]]:
+    """Fixture to provide with preset TorchOptiModule constructors."""
+    defaults = TorchOptiModule.__init__.__defaults__
+    TorchOptiModule.__init__.__defaults__ = (optim_cls, False)
+    with mock.patch("builtins.input") as patch_input:
+        patch_input.return_value = "y"
+        yield TorchOptiModule
+    TorchOptiModule.__init__.__defaults__ = defaults
+
+
+def get_optim_dec(optim_cls: str) -> OptiModule:
+    """Instanciate an OptiModule parameterized to match a torch one."""
+    optim = optim_cls.rsplit(".", 1)[-1].lower()
+    if optim == "adam":
+        name = "adam"
+        kwargs = {"beta_1": 0.9, "beta_2": 0.999, "eps": 1e-08}
+    elif optim == "rmsprop":
+        name = "rmsprop"
+        kwargs = {"beta": 0.99, "eps": 1e-08}
+    elif optim == "adagrad":
+        name = "adagrad"
+        kwargs = {"eps": 1e-10}
+    else:
+        raise KeyError(f"Invalid 'optim' fixture parameter value: '{optim}'.")
+    return OptiModule.from_specs(name, kwargs)
+
+
+@pytest.fixture(name="framework")
+def framework_fixture():
+    """Fixture to ensure 'TorchOptiModule' only receives torch gradients."""
+    return "torch"
+
+
+@pytest.mark.parametrize("optim", ["adam", "rmsprop", "adagrad"])
+class TestTorchOptiModule(OptiModuleTestSuite):
+    """Unit tests for declearn.model.torch.TorchOptiModule."""
+
+    def test_run_equivalence(self, cls: Type[OptiModule]) -> None:
+        # This test is undefined for a framework-specific plugin.
+        pass
+
+    def test_validate_torch(self, optim_cls: str) -> None:
+        """Test that user-validation of torch modules' import is skipped."""
+        with mock.patch("builtins.input") as patch_input:
+            patch_input.return_value = "y"
+            TorchOptiModule(optim_cls, validate=True)
+            patch_input.assert_not_called()
+
+    def test_equivalent_declearn(self, optim_cls: str) -> None:
+        """Test that declearn modules are equivalent to torch ones.
+
+        Instantiate a TorchOptiModule wrapping a torch.optim.Optimizer,
+        as well as a declearn OptiModule that matches its configuration
+        (adjusting hyper-parameters to torch's defaults).
+
+        Ensure that on 10 successive passes on the same random-valued
+        input gradients, outputs have the same values, up to numerical
+        precision (relative tolerance of 10^-5, absolute of 10^-8).
+        """
+        optim_pyt = TorchOptiModule(optim_cls, validate=False)
+        optim_dec = get_optim_dec(optim_cls)
+        gradients = GradientsTestCase("torch").mock_gradient
+        for _ in range(10):
+            grads_pyt = optim_pyt.run(gradients).coefs
+            grads_dec = optim_dec.run(gradients).coefs
+            assert all(
+                np.allclose(grads_pyt[key].numpy(), grads_dec[key].numpy())
+                for key in gradients.coefs
+            )
+
+    def test_reset(self, optim_cls: str) -> None:
+        """Test that the `TorchOptiModule.reset` method works."""
+        # Set up a module and two sets of differently-labeled gradients.
+        module = TorchOptiModule(optim_cls, validate=False)
+        grads_a = GradientsTestCase("torch").mock_gradient
+        grads_b = TorchVector(
+            {f"{key}_bis": val for key, val in grads_a.coefs.items()}
+        )
+        # Check that running inconsistent gradients fails.
+        outpt_a = module.run(grads_a)
+        assert isinstance(outpt_a, TorchVector)
+        with pytest.raises(KeyError):
+            module.run(grads_b)
+        # Test that resetting enables setting up a new input spec.
+        module.reset()
+        outpt_b = module.run(grads_b)
+        assert isinstance(outpt_b, TorchVector)
+        # Check that results are indeed the same, save for their names.
+        # This means inner states have been properly reset.
+        outpt_a = TorchVector(
+            {f"{key}_bis": val for key, val in outpt_a.coefs.items()}
+        )
+        assert outpt_b == outpt_a
+
+    @pytest.mark.parametrize("device", DEVICES)
+    def test_device_placement(self, optim_cls: str, device: str) -> None:
+        """Test that the optimizer and computations are properly placed."""
+        # Set the device policy, setup a module and run computations.
+        set_device_policy(gpu=(device == "GPU"), idx=None)
+        module = TorchOptiModule(optim_cls)
+        grads = GradientsTestCase("torch").mock_gradient
+        if device == "GPU":
+            for key, tensor in grads.coefs.items():
+                grads.coefs[key] = tensor.cuda()
+        updts = module.run(grads)
+        # Assert that the outputs and internal states are properly placed.
+        dtype = "cuda" if device == "GPU" else "cpu"
+        assert all(t.device.type == dtype for t in updts.coefs.values())
+        state = module.get_state()["state"]
+        assert all(
+            tensor.device.type == dtype
+            for _, group in state
+            for tensor in group
+            if isinstance(tensor, torch.Tensor)
+        )
+        # Reset device policy to run other tests on CPU as expected.
+        set_device_policy(gpu=False)
+
+
+class FakeOptimizer(torch.optim.Optimizer):
+    """Fake torch Optimizer subclass."""
+
+    step = mock.create_autospec(torch.optim.Optimizer.step)
+
+
+class EmptyClass:
+    """Empty class to test non-torch-optimizer imports' failure."""
+
+    # mock class; pylint: disable=too-few-public-methods
+
+    __init__ = mock.MagicMock()
+
+
+class TestTorchOptiModuleValidateImports:
+    """Test that user-validation of third-party modules' import works."""
+
+    def test_validate_accept(self) -> None:
+        """Assert that (fake) user inputs enable validating imports."""
+        # Set up the import string for FakeOptimizer and pre-import its module.
+        string = f"{__name__}.FakeOptimizer"
+        module = importlib.import_module(__name__)
+        # Run the TorchOptiModule instantiation with patched objects.
+        with mock.patch("builtins.input") as patch_input:
+            with mock.patch("importlib.import_module") as patch_import:
+                patch_input.return_value = "y"
+                patch_import.return_value = module
+                optim = TorchOptiModule(string, validate=True)
+        # Assert the expected calls were made and FakeOptimizer was assigned.
+        patch_input.assert_called_once()
+        patch_import.assert_called_once_with(__name__)
+        assert optim.optim_cls is FakeOptimizer
+
+    def test_validate_deny(self) -> None:
+        """Assert (fake) user inputs enable blocking imports."""
+        # Run the TorchOptiModule instantiation with fake user denial command.
+        with mock.patch("builtins.input") as patch_input:
+            with mock.patch("importlib.import_module") as patch_import:
+                patch_input.return_value = "n"
+                with pytest.raises(RuntimeError):
+                    TorchOptiModule(f"{__name__}.FakeOptimizer", validate=True)
+        # Assert the expected calls were made (or not).
+        patch_input.assert_called_once()
+        patch_import.assert_not_called()
+
+    def test_validate_wrong_class(self) -> None:
+        """Assert importing an invalid class raises a TypeError."""
+        # Set up the import string for EmptyClass and pre-import its module.
+        string = f"{__name__}.EmptyClass"
+        module = importlib.import_module(__name__)
+        # Run the TorchOptiModule instantiation with patched objects.
+        with mock.patch("builtins.input") as patch_input:
+            with mock.patch("importlib.import_module") as patch_import:
+                patch_input.return_value = "y"
+                patch_import.return_value = module
+                with pytest.raises(TypeError):
+                    TorchOptiModule(string, validate=True)
+        # Assert the expected calls were made.
+        patch_input.assert_called_once()
+        patch_import.assert_called_once_with(__name__)