diff --git a/README.md b/README.md
index 79f103078166eee32305da4ebf2c7d083a81e3d8..b638674b1b558ad4258255087de58892729d0753 100644
--- a/README.md
+++ b/README.md
@@ -227,6 +227,38 @@ client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs"
 client.run()
 ```
 
+### Support for GPU acceleration
+
+TL;DR: GPU acceleration is natively available in `declearn` for model 
+frameworks that support it, with one line of code and without changing 
+your original model.
+
+Details:
+
+Most machine learning frameworks, including TensorFlow and PyTorch, support 
+accelerating computations by using computational devices other than the CPU. 
+`declearn` interfaces with the supported frameworks so that a device policy 
+can be set in a single line of code, consistently across frameworks. 
+
+`declearn` internalizes the framework-specific code required to place the
+data, model weights and computations on such a device. `declearn` provides 
+a simple API to define a global device policy, enabling the use of a single 
+GPU to accelerate computations, or forcing the use of the CPU. 
+
+By default, the policy is set to use the first available GPU, falling back 
+to the CPU otherwise, with a warning that can safely be ignored.
+
+Setting the device policy is done in local scripts, whether they run on the 
+client side or on the server side. The device policy is local and is never 
+synchronized between federated learning participants.
+
+Here are a few examples of this one-liner:
+```python
+declearn.utils.set_device_policy(gpu=False)  # disable GPU use
+declearn.utils.set_device_policy(gpu=True)  # use any available GPU
+declearn.utils.set_device_policy(gpu=True, idx=2)  # specifically use GPU n°2
+```
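+
+You may also read back the current policy at any time; a minimal sketch:
+
+```python
+policy = declearn.utils.get_device_policy()
+print(policy.gpu, policy.idx)  # e.g. True, None (the default policy)
+```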
+
 ### Note on dependency sharing
 
 One important issue however that is not handled by declearn itself is that
diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index 6bd5fa081494f3862affba1fdbf95b2919228ebd..41d2c9abd09a0530c2b9b0b9629acd560467a848 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -25,7 +25,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api._vector import Vector
 from declearn.typing import Batch
-from declearn.utils import create_types_registry
+from declearn.utils import DevicePolicy, create_types_registry
 
 
 __all__ = [
@@ -43,6 +43,13 @@ class Model(metaclass=ABCMeta):
     writing algorithms and operations agnostic to the framework
     in which the underlying model is implemented (e.g. PyTorch,
     TensorFlow, Scikit-Learn...).
+
+    Device-placement (i.e. running computations on CPU or GPU)
+    is also handled as part of Model classes' backend, mapping
+    the generic `declearn.utils.DevicePolicy` parameters to the
+    framework-specific instructions required to pick the device
+    to use and to ensure that the wrapped model, input data and
+    interfaced computations are placed there.
     """
 
     def __init__(
@@ -52,6 +59,13 @@ class Model(metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    @property
+    @abstractmethod
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        """Return the device-placement policy currently used by this model."""
+
     @property
     @abstractmethod
     def required_data_info(
@@ -220,3 +234,27 @@ class Model(metaclass=ABCMeta):
         s_loss: np.ndarray
             Sample-wise loss values, as a 1-d numpy array.
         """
+
+    @abstractmethod
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        """Update the device-placement policy of this model.
+
+        This method is designed to be called after a change in the global
+        device-placement policy (e.g. to disable using a GPU, or move to
+        a specific one), so as to place pre-existing Model instances and
+        avoid policy inconsistencies that might cause repeated memory or
+        runtime costs from moving data or weights around each time they
+        are used. You should otherwise not worry about a Model's device-
+        placement, as it is handled at instantiation based on the global
+        device policy (see `declearn.utils.set_device_policy`).
+
+        Parameters
+        ----------
+        policy: DevicePolicy or None, default=None
+            Optional DevicePolicy dataclass instance to be used.
+            If None, use the global device policy, accessed via
+            `declearn.utils.get_device_policy`.
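+
+        Example
+        -------
+        A minimal sketch (assuming `model` is a `Model` subclass instance):
+
+        >>> from declearn.utils import set_device_policy
+        >>> set_device_policy(gpu=False)   # update the global policy
+        >>> model.update_device_policy()   # move this model accordingly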
+        """
diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py
index c55986773c830b5d0ddaecdee596879ee1efdfb5..fc6fdc78b15ddc5eb406b10d3ffef50d50ec2857 100644
--- a/declearn/model/api/_vector.py
+++ b/declearn/model/api/_vector.py
@@ -315,10 +315,7 @@ class Vector(metaclass=ABCMeta):
             return type(self)(coefs)
         # Case when the two vectors have incompatible types.
         if isinstance(other, Vector):
-            raise TypeError(
-                f"Cannot {func.__name__} {type(self).__name__} object with "
-                f"a vector of incompatible type {type(other).__name__}."
-            )
+            return NotImplemented
         # Case when operating with another object (e.g. a scalar).
         try:
             return type(self)(
diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 5903f7dd91ce4d7707d22511645899fd047d2b2e..65ac351d93904c3375f6d2244f8b271f3a067eb1 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -41,6 +41,19 @@ class NumpyVector(Vector):
     instances with similar coefficients specifications).
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `NumpyVector` can be operated with either a scalar value,
+      or another `NumpyVector` that has similar specifications
+      (same coefficient names, shapes and compatible dtypes).
+    - Some other `Vector` classes might be made compatible with
+      `NumpyVector`; in that case, operating with a `NumpyVector`
+      will always result in a vector of the other type. This is
+      notably the case with `TensorflowVector` and `TorchVector`.
+    - There is currently no support for GPU acceleration with the
+      `NumpyVector` class, which only handles arrays and operations
+      placed on a CPU device.
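+
+    Example
+    -------
+    A minimal sketch of supported operations:
+
+    >>> vec = NumpyVector({"a": np.ones(2)})
+    >>> sum_vec = vec + vec    # NumpyVector wrapping array([2., 2.])
+    >>> half_vec = vec * 0.5   # operating with a scalar also works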
     """
 
     @property
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index eabc2883ed169072bae0e31219082ec72703d1c2..28abc4913bf30b1f3e4899c41b7c3555f71394d2 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -18,6 +18,7 @@
 """Model subclass to wrap scikit-learn SGD classifier and regressor models."""
 
 import typing
+import warnings
 from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -29,7 +30,7 @@ from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
 from declearn.model.sklearn._np_vec import NumpyVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, register_type
 
 
 __all__ = [
@@ -63,6 +64,13 @@ class SklearnSGDModel(Model):
     This `Model` subclass is designed to wrap a `SGDClassifier`
     or `SGDRegressor` instance (from `sklearn.linear_model`) to
     be learned federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * This Model may only run on CPU, and is unaffected by device-
+      management policies.
+    * Calling the `update_device_policy` method has no effect, apart
+      from emitting a UserWarning when a GPU-targeting policy is
+      passed to it directly.
     """
 
     def __init__(
@@ -104,6 +112,12 @@ class SklearnSGDModel(Model):
             None
         )  # type: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        return DevicePolicy(gpu=False, idx=None)
+
     @property
     def required_data_info(
         self,
@@ -382,3 +396,10 @@ class SklearnSGDModel(Model):
         else:
             loss_fn = loss_1d
         return loss_fn
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
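+        # This model only runs on CPU: warn if a GPU policy is explicitly passed.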
+        if policy is not None and policy.gpu:
+            warnings.warn("'SklearnSGDModel' only runs on a CPU backend.")
diff --git a/declearn/model/tensorflow/__init__.py b/declearn/model/tensorflow/__init__.py
index ff94c495b4647614cc5f9e9e44b0a9d636f3c3d3..a8db61f23c5e09cb1a25b2c2abdf70a9128d9972 100644
--- a/declearn/model/tensorflow/__init__.py
+++ b/declearn/model/tensorflow/__init__.py
@@ -24,7 +24,11 @@ through gradient descent.
 This module exposes:
 * TensorflowModel: Model subclass to wrap tensorflow.keras.Model objects
 * TensorflowVector: Vector subclass to wrap tensorflow.Tensor objects
+
+It also exposes the `utils` submodule, which mainly provides tools
+used in the backend of the above objects.
 """
 
+from . import utils
 from ._vector import TensorflowVector
 from ._model import TensorflowModel
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index fc1a9eb88dc4d42f07314276b266d20749661318..a56ff2096e0f4d50eb32c65aeb84bf268d30f9fe 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -28,18 +28,43 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
-from declearn.model.tensorflow._utils import build_keras_loss
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.model.tensorflow._vector import TensorflowVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TensorflowModel",
+]
 
 
 @register_type(name="TensorflowModel", group="Model")
 class TensorflowModel(Model):
     """Model wrapper for TensorFlow Model instances.
 
-    This `Model` subclass is designed to wrap a `tf.keras.Model`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `tf.keras.Model` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default, tensorflow places data and operations on GPU whenever one
+      is available.
+    * Our `TensorflowModel` instead consults the device-placement policy (via
+      `declearn.utils.get_device_policy`), places the wrapped keras model's
+      weights on the selected device, and runs computations defined under
+      public methods in a `tensorflow.device` context, to enforce that policy.
+    * Note that there is no guarantee that calling a private method directly
+      will result in abiding by that policy. Hence, be careful when writing
+      custom code, and use your own context managers to get guarantees.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy currently enforced by a TensorflowModel
+      instance by accessing its `device_policy` property.
     """
 
     def __init__(
@@ -47,6 +72,7 @@ class TensorflowModel(Model):
         model: tf.keras.layers.Layer,
         loss: Union[str, tf.keras.losses.Loss],
         metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None,
+        _from_config: bool = False,
         **kwargs: Any,
     ) -> None:
         """Instantiate a Model interface wrapping a tensorflow.keras model.
@@ -66,7 +92,7 @@ class TensorflowModel(Model):
             compiled with the model and computed using the `evaluate`
             method of the returned TensorflowModel instance.
         **kwargs: Any
-            Any addition keyword argument to `tf.keras.Model.compile`
+            Any additional keyword argument to `tf.keras.Model.compile`
             may be passed.
         """
         # Type-check the input Model and wrap it up.
@@ -79,12 +105,30 @@ class TensorflowModel(Model):
         super().__init__(model)
         # Ensure the loss is a keras.Loss object and set its reduction to none.
         loss = build_keras_loss(loss, reduction=tf.keras.losses.Reduction.NONE)
-        # Compile the wrapped model and retain compilation arguments.
-        kwargs.update({"loss": loss, "metrics": metrics})
-        model.compile(**kwargs)
-        self._kwargs = kwargs
-        # Instantiate a SGD optimizer to apply updates as-provided.
-        self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0)
+        # Select the device where to place computations and move the model.
+        policy = get_device_policy()
+        self._device = select_device(gpu=policy.gpu, idx=policy.idx)
+        if not _from_config:
+            self._model = move_layer_to_device(self._model, self._device)
+        # Finalize initialization using the selected device.
+        with tf.device(self._device):
+            # Compile the wrapped model and retain compilation arguments.
+            kwargs.update({"loss": loss, "metrics": metrics})
+            self._model.compile(**kwargs)
+            self._kwargs = kwargs
+            # Instantiate a SGD optimizer to apply updates as-provided.
+            self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0)
+
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._device
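+        # Parse the device index from its name (e.g. "/device:GPU:0").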
+        try:
+            idx = int(device.name.rsplit(":", 1)[-1])
+        except ValueError:
+            idx = None
+        return DevicePolicy(gpu=(device.device_type == "GPU"), idx=idx)
 
     @property
     def required_data_info(
@@ -98,7 +142,8 @@ class TensorflowModel(Model):
     ) -> None:
         if not self._model.built:
             data_info = aggregate_data_info([data_info], {"input_shape"})
-            self._model.build(data_info["input_shape"])
+            with tf.device(self._device):
+                self._model.build(data_info["input_shape"])
         # Warn about frozen weights.
         # similar to TorchModel warning; pylint: disable=duplicate-code
         if len(self._model.trainable_weights) < len(self._model.weights):
@@ -129,9 +174,15 @@ class TensorflowModel(Model):
         for key in ("model", "loss", "kwargs"):
             if key not in config.keys():
                 raise KeyError(f"Missing key '{key}' in the config dict.")
-        model = tf.keras.layers.deserialize(config["model"])
-        loss = tf.keras.losses.deserialize(config["loss"])
-        return cls(model, loss, **config["kwargs"])
+        # Set up the device policy.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Deserialize the model and loss keras objects on the device.
+        with tf.device(device):
+            model = tf.keras.layers.deserialize(config["model"])
+            loss = tf.keras.losses.deserialize(config["loss"])
+        # Instantiate the TensorflowModel, avoiding device-to-device copies.
+        return cls(model, loss, **config["kwargs"], _from_config=True)
 
     def get_weights(
         self,
@@ -151,8 +202,9 @@ class TensorflowModel(Model):
             )
         self._verify_weights_compatibility(weights, trainable_only=False)
         variables = {var.name: var for var in self._model.weights}
-        for name, value in weights.coefs.items():
-            variables[name].assign(value, read_value=False)
+        with tf.device(self._device):
+            for name, value in weights.coefs.items():
+                variables[name].assign(value, read_value=False)
 
     def _verify_weights_compatibility(
         self,
@@ -197,12 +249,13 @@ class TensorflowModel(Model):
         batch: Batch,
         max_norm: Optional[float] = None,
     ) -> TensorflowVector:
-        data = self._unpack_batch(batch)
-        if max_norm is None:
-            grads = self._compute_batch_gradients(*data)
-        else:
-            norm = tf.constant(max_norm)
-            grads = self._compute_clipped_gradients(*data, norm)
+        with tf.device(self._device):
+            data = self._unpack_batch(batch)
+            if max_norm is None:
+                grads = self._compute_batch_gradients(*data)
+            else:
+                norm = tf.constant(max_norm)
+                grads = self._compute_clipped_gradients(*data, norm)
         grads_and_vars = zip(grads, self._model.trainable_weights)
         return TensorflowVector(
             {var.name: grad for grad, var in grads_and_vars}
@@ -285,13 +338,14 @@ class TensorflowModel(Model):
         updates: TensorflowVector,
     ) -> None:
         self._verify_weights_compatibility(updates, trainable_only=True)
-        # Delegate updates' application to a tensorflow Optimizer.
-        values = (-1 * updates).coefs.values()
-        zipped = zip(values, self._model.trainable_weights)
-        upd_op = self._sgd.apply_gradients(zipped)
-        # Ensure ops have been performed before exiting.
-        with tf.control_dependencies([upd_op]):
-            return None
+        with tf.device(self._device):
+            # Delegate updates' application to a tensorflow Optimizer.
+            values = (-1 * updates).coefs.values()
+            zipped = zip(values, self._model.trainable_weights)
+            upd_op = self._sgd.apply_gradients(zipped)
+            # Ensure ops have been performed before exiting.
+            with tf.control_dependencies([upd_op]):
+                return None
 
     def evaluate(
         self,
@@ -311,21 +365,23 @@ class TensorflowModel(Model):
         metrics: dict[str, float]
             Dictionary associating evaluation metrics' values to their name.
         """
-        return self._model.evaluate(dataset, return_dict=True)
+        with tf.device(self._device):
+            return self._model.evaluate(dataset, return_dict=True)
 
     def compute_batch_predictions(
         self,
         batch: Batch,
     ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
-        inputs, y_true, s_wght = self._unpack_batch(batch)
-        if y_true is None:
-            raise TypeError(
-                "`TensorflowModel.compute_batch_predictions` received a "
-                "batch with `y_true=None`, which is unsupported. Please "
-                "correct the inputs, or override this method to support "
-                "creating labels from the base inputs."
-            )
-        y_pred = self._model(inputs, training=False).numpy()
+        with tf.device(self._device):
+            inputs, y_true, s_wght = self._unpack_batch(batch)
+            if y_true is None:
+                raise TypeError(
+                    "`TensorflowModel.compute_batch_predictions` received a "
+                    "batch with `y_true=None`, which is unsupported. Please "
+                    "correct the inputs, or override this method to support "
+                    "creating labels from the base inputs."
+                )
+            y_pred = self._model(inputs, training=False).numpy()
         y_true = y_true.numpy()
         s_wght = s_wght.numpy() if s_wght is not None else s_wght
         return y_true, y_pred, s_wght
@@ -335,7 +391,21 @@ class TensorflowModel(Model):
         y_true: np.ndarray,
         y_pred: np.ndarray,
     ) -> np.ndarray:
-        tns_true = tf.convert_to_tensor(y_true)
-        tns_pred = tf.convert_to_tensor(y_pred)
-        s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred)
+        with tf.device(self._device):
+            tns_true = tf.convert_to_tensor(y_true)
+            tns_pred = tf.convert_to_tensor(y_pred)
+            s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred)
         return s_loss.numpy().squeeze()
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        # Select the device to use based on the provided or global policy.
+        if policy is None:
+            policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # When needed, re-create the model to force moving it to the device.
+        if self._device is not device:
+            self._device = device
+            self._model = move_layer_to_device(self._model, self._device)
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index f10fa5154d3d82b65d747de31d2d8824bc9bbb2d..f347f212b97c37c3176401004f39d4d09aadf89a 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -25,30 +25,55 @@ import tensorflow as tf  # type: ignore
 from tensorflow.python.framework.ops import EagerTensor  # type: ignore
 # pylint: enable=no-name-in-module
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
+# fmt: on
 
 from declearn.model.api import Vector, register_vector_type
 from declearn.model.sklearn import NumpyVector
-
-# fmt: on
+from declearn.model.tensorflow.utils import (
+    preserve_tensor_device,
+    select_device,
+)
+from declearn.utils import get_device_policy
 
 
 @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
 class TensorflowVector(Vector):
     """Vector subclass to store tensorflow tensors.
 
-    This Vector is designed to store a collection of named
-    TensorFlow tensors, enabling computations that are either
-    applied to each and every coefficient, or imply two sets
-    of aligned coefficients (i.e. two TensorflowVector with
-    similar specifications).
+    This Vector is designed to store a collection of named TensorFlow
+    tensors, enabling computations that are either applied to each and
+    every coefficient, or imply two sets of aligned coefficients (i.e.
+    two TensorflowVector with similar specifications).
+
+    Note that support for IndexedSlices is implemented, as these are a
+    common type for auto-differentiated gradients.
 
-    Note that support for IndexedSlices is implemented,
-    as these are a common type for auto-differentiated
-    gradients.
-    Note that this class does not (yet?) support special
-    tensor types such as SparseTensor or RaggedTensor.
+    Note that this class does not (yet?) support special tensor types
+    such as SparseTensor or RaggedTensor.
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `TensorflowVector` can be operated with either:
+      - a scalar value
+      - a `NumpyVector` that has similar specifications
+      - a `TensorflowVector` that has similar specifications
+      => resulting in a `TensorflowVector` in each of these cases.
+    - The wrapped tensors may be placed on any device (CPU, GPU...)
+      and may not all be on the same device.
+    - The device-placement of the initial `TensorflowVector`'s data
+      is preserved by operations, including those with `NumpyVector`.
+    - When combining two `TensorflowVector` instances, the device-
+      placement of the left-most one is used; as a result, one ends
+      up with `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both
+      cases, a warning is emitted to prevent silent un-optimized
+      copies.
+    - When deserializing a `TensorflowVector` (either by directly using
+      `TensorflowVector.unpack` or loading one from a JSON dump), loaded
+      tensors are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement scheme than at dump time, but
+      one coherent with that of `TensorflowModel` computations.
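+
+    Example
+    -------
+    A minimal sketch of supported operations (assuming `tf` is imported):
+
+    >>> vec = TensorflowVector({"a": tf.ones(2)})
+    >>> sum_vec = vec + vec    # TensorflowVector, on `vec`'s device
+    >>> half_vec = vec * 0.5   # operating with a scalar also works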
     """
 
     @property
@@ -81,6 +106,23 @@ class TensorflowVector(Vector):
     ) -> None:
         super().__init__(coefs)
 
+    def apply_func(
+        self,
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Self:
+        func = preserve_tensor_device(func)
+        return super().apply_func(func, *args, **kwargs)
+
+    def _apply_operation(
+        self,
+        other: Any,
+        func: Callable[[Any, Any], Any],
+    ) -> Self:
+        func = preserve_tensor_device(func)
+        return super()._apply_operation(other, func)
+
     def dtypes(
         self,
     ) -> Dict[str, str]:
@@ -97,7 +139,10 @@ class TensorflowVector(Vector):
         cls,
         data: Dict[str, Any],
     ) -> Self:
-        coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()}
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        with tf.device(device):
+            coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()}
         return cls(coef)
 
     @classmethod
@@ -149,10 +194,13 @@ class TensorflowVector(Vector):
         if not isinstance(t_a, type(t_b)):
             return False
         if isinstance(t_a, tf.IndexedSlices):
-            return TensorflowVector._tensor_equal(
-                t_a.indices, t_b.indices
-            ) and TensorflowVector._tensor_equal(t_a.values, t_b.values)
-        return tf.reduce_all(t_a == t_b).numpy()
+            # fmt: off
+            return (
+                TensorflowVector._tensor_equal(t_a.indices, t_b.indices)
+                and TensorflowVector._tensor_equal(t_a.values, t_b.values)
+            )
+        with tf.device(t_a.device):
+            return tf.reduce_all(t_a == t_b).numpy()
 
     def sign(self) -> Self:
         return self.apply_func(tf.sign)
@@ -178,8 +226,4 @@ class TensorflowVector(Vector):
         axis: Optional[int] = None,
         keepdims: bool = False,
     ) -> Self:
-        coefs = {
-            key: tf.reduce_sum(val, axis=axis, keepdims=keepdims)
-            for key, val in self.coefs.items()
-        }
-        return self.__class__(coefs)
+        return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims)
diff --git a/declearn/model/tensorflow/utils/__init__.py b/declearn/model/tensorflow/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f7d66b1d29e47c7b07dc8b10f1b6e9d11187896
--- /dev/null
+++ b/declearn/model/tensorflow/utils/__init__.py
@@ -0,0 +1,38 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for tensorflow backend support code.
+
+GPU/CPU backing device management utils:
+* move_layer_to_device:
+    Create a copy of an input keras layer placed on a given device.
+* preserve_tensor_device:
+    Wrap a tensor-processing function to have it run on its inputs' device.
+* select_device:
+    Select a backing device to use based on inputs and availability.
+
+Loss function management utils:
+* build_keras_loss:
+    Type-check, deserialize and/or wrap a keras loss into a Loss object.
+"""
+
+from ._gpu import (
+    move_layer_to_device,
+    preserve_tensor_device,
+    select_device,
+)
+from ._loss import build_keras_loss
diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8164250ab560d4db1e08146f960d3ba3f4c6a6d8
--- /dev/null
+++ b/declearn/model/tensorflow/utils/_gpu.py
@@ -0,0 +1,137 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for GPU support and device management in tensorflow."""
+
+import functools
+import warnings
+from typing import Any, Callable, Optional, Union
+
+import tensorflow as tf  # type: ignore
+
+
+__all__ = [
+    "move_layer_to_device",
+    "preserve_tensor_device",
+    "select_device",
+]
+
+
+def select_device(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> tf.config.LogicalDevice:
+    """Select a backing device to use based on inputs and availability.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to select a GPU device rather than the CPU one.
+    idx: int or None, default=None
+        Optional pre-selected device index. Only used when `gpu=True`.
+        If `idx is None` or exceeds the number of available GPU devices,
+        use the first available one.
+
+    Warns
+    -----
+    UserWarning:
+        If `gpu=True` but no GPU is available.
+        If `idx` exceeds the number of available GPU devices.
+
+    Returns
+    -------
+    device: tf.config.LogicalDevice
+        Selected device, usable as `tf.device` argument.
+    """
+    idx = 0 if idx is None else idx
+    # List available CPU or GPU devices.
+    device_type = "GPU" if gpu else "CPU"
+    devices = tf.config.list_logical_devices(device_type)
+    # Case when no GPU is available: warn and use a CPU instead.
+    if gpu and not devices:
+        warnings.warn(
+            "Cannot use a GPU device: either CUDA is unavailable "
+            "or no GPU is visible to tensorflow."
+        )
+        device_type, idx = "CPU", 0
+        devices = tf.config.list_logical_devices("CPU")
+    # Case when the desired device index is invalid: select another one.
+    if idx >= len(devices):
+        warnings.warn(
+            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
+            f"\nUsing {device_type} device n°0 instead."
+        )
+        idx = 0
+    # Return the selected device.
+    return devices[idx]
+
+
+def move_layer_to_device(
+    layer: tf.keras.layers.Layer,
+    device: Union[tf.config.LogicalDevice, str],
+) -> tf.keras.layers.Layer:
+    """Create a copy of an input keras layer placed on a given device.
+
+    This function creates a copy of the input layer and of all its weights.
+    It may therefore be costly and should be used sparingly: use it to move
+    variables once, to a device where all further computations are expected
+    to run.
+
+    Parameters
+    ----------
+    layer: tf.keras.layers.Layer
+        Keras layer that needs moving to another device.
+    device: tf.config.LogicalDevice or str
+        Device where to place the layer's weights.
+
+    Returns
+    -------
+    layer: tf.keras.layers.Layer
+        Copy of the input layer, with its weights backed on `device`.
+    """
+    config = tf.keras.layers.serialize(layer)
+    weights = layer.get_weights()
+    with tf.device(device):
+        layer = tf.keras.layers.deserialize(config)
+        layer.set_weights(weights)
+    return layer
+
+
+def preserve_tensor_device(
+    func: Callable[..., tf.Tensor],
+) -> Callable[..., tf.Tensor]:
+    """Wrap a tensor-processing function to have it run on its inputs' device.
+
+    Parameters
+    ----------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Function to wrap, that takes a tensorflow Tensor as first argument.
+
+    Returns
+    -------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Similar function to the input one, that operates under a `tf.device`
+        context so as to run computations on the first input tensor's device.
+    """
+
+    @functools.wraps(func)
+    def wrapped(tensor: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
+        """Wrapped function, running under a `tf.device` context."""
+        with tf.device(tensor.device):
+            return func(tensor, *args, **kwargs)
+
+    return wrapped
diff --git a/declearn/model/tensorflow/_utils.py b/declearn/model/tensorflow/utils/_loss.py
similarity index 98%
rename from declearn/model/tensorflow/_utils.py
rename to declearn/model/tensorflow/utils/_loss.py
index b858e4e4631862a07b8639a8a0229fc6717d547d..c6b880c8e8aaa6dc90cb197e18bf7677f32285fc 100644
--- a/declearn/model/tensorflow/_utils.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Backend utils for the declearn.model.tensorflow module."""
+"""Function to parse and/or wrap a keras loss for use with declearn."""
 
 import inspect
 
diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index 352d3a398f6958488404c2ecd5aecc4a0f29a43b..efdf95f1d5a2ab94b777667fcd3787fb27fdea29 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -24,7 +24,11 @@ gradient descent.
 This module exposes:
 * TorchModel: Model subclass to wrap torch.nn.Module objects
 * TorchVector: Vector subclass to wrap torch.Tensor objects
+
+It also exposes the `utils` submodule, which mainly provides tools
+used in the backend of the above objects.
 """
 
+from . import utils
 from ._vector import TorchVector
 from ._model import TorchModel
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 488d7782be84ffc9ef38ec9cee4d38bf31ff6d37..4bdea3aaeab477078dd28a24bcc32b0fcc68c3ee 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -27,9 +27,15 @@ import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
+from declearn.model.torch.utils import AutoDeviceModule, select_device
 from declearn.model.torch._vector import TorchVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TorchModel",
+]
 
 
 # alias for unpacked Batch structures, converted to torch.Tensor objects
@@ -42,8 +48,23 @@ TensorBatch = Tuple[
 class TorchModel(Model):
     """Model wrapper for PyTorch Model instances.
 
-    This `Model` subclass is designed to wrap a `torch.nn.Module`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `torch.nn.Module` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default torch operates on CPU, and it does not automatically move
+      tensors between devices. This means users have to be careful where
+      tensors are placed to avoid operations between tensors on different
+      devices, leading to runtime errors.
+    * Our `TorchModel` instead consults the global device-placement policy
+      (via `declearn.utils.get_device_policy`), places the wrapped torch
+      modules' weights there, and automates the placement of input data on
+      the same device as the wrapped model.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy currently enforced by a TorchModel
+      instance by accessing its `device_policy` property.
     """
 
     def __init__(
@@ -62,18 +83,29 @@ class TorchModel(Model):
             is to be minimized through training. Note that it will be
             altered when wrapped.
         """
-        # Type-check the input Model and wrap it up.
+        # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
             raise TypeError("'model' should be a torch.nn.Module instance.")
+        # Select the device where to place computations, and wrap the model.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        model = AutoDeviceModule(model, device=device)
         super().__init__(model)
         # Assign loss module and set it not to reduce sample-wise values.
         if not isinstance(loss, torch.nn.Module):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
-        self._loss_fn = loss
-        self._loss_fn.reduction = "none"  # type: ignore
+        loss.reduction = "none"  # type: ignore
+        self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
         self._func_model = functorch.make_functional(self._model)[0]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._model.device
+        return DevicePolicy(gpu=(device.type == "cuda"), idx=device.index)
+
     @property
     def required_data_info(
         self,
@@ -102,15 +134,12 @@ class TorchModel(Model):
             "PyTorch JSON serialization relies on pickle, which may be unsafe."
         )
         with io.BytesIO() as buffer:
-            torch.save(self._model, buffer)
+            torch.save(self._model.module, buffer)
             model = buffer.getbuffer().hex()
         with io.BytesIO() as buffer:
-            torch.save(self._loss_fn, buffer)
+            torch.save(self._loss_fn.module, buffer)
             loss = buffer.getbuffer().hex()
-        return {
-            "model": model,
-            "loss": loss,
-        }
+        return {"model": model, "loss": loss}
 
     @classmethod
     def from_config(
@@ -139,6 +168,7 @@ class TorchModel(Model):
     ) -> None:
         if not isinstance(weights, TorchVector):
             raise TypeError("TorchModel requires TorchVector weights.")
+        # NOTE: this preserves the device placement of current states
         self._model.load_state_dict(weights.coefs)
 
     def compute_batch_gradients(
@@ -197,7 +227,7 @@ class TorchModel(Model):
         """Compute the average (opt. weighted) loss over given predictions."""
         loss = self._loss_fn(y_pred, y_true)
         if s_wght is not None:
-            loss.mul_(s_wght)
+            loss.mul_(s_wght.to(loss.device))
         return loss.mean()
 
     def _compute_samplewise_gradients(
@@ -236,7 +266,7 @@ class TorchModel(Model):
                 # false-positive; pylint: disable=no-member
                 grad.mul_(torch.clamp(max_norm / norm, max=1))
                 if s_wght is not None:
-                    grad.mul_(s_wght)
+                    grad.mul_(s_wght.to(grad.device))
             return grads
         # Vectorize the function to compute sample-wise clipped gradients.
         with torch.no_grad():
@@ -322,7 +352,7 @@ class TorchModel(Model):
             try:
                 for key, upd in updates.coefs.items():
                     tns = self._model.get_parameter(key)
-                    tns.add_(upd)
+                    tns.add_(upd.to(tns.device))
             except KeyError as exc:
                 raise KeyError(
                     "Invalid model parameter name(s) found in updates."
@@ -342,9 +372,9 @@ class TorchModel(Model):
             )
         self._model.eval()
         with torch.no_grad():
-            y_pred = self._model(*inputs).numpy()
-        y_true = y_true.numpy()
-        s_wght = s_wght.numpy() if s_wght is not None else s_wght
+            y_pred = self._model(*inputs).cpu().numpy()
+        y_true = y_true.cpu().numpy()
+        s_wght = None if s_wght is None else s_wght.cpu().numpy()
         return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
@@ -355,4 +385,16 @@ class TorchModel(Model):
         tns_pred = torch.from_numpy(y_pred)  # pylint: disable=no-member
         tns_true = torch.from_numpy(y_true)  # pylint: disable=no-member
         s_loss = self._loss_fn(tns_pred, tns_true)
-        return s_loss.numpy().squeeze()
+        return s_loss.cpu().numpy().squeeze()
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        # Select the device to use based on the provided or global policy.
+        if policy is None:
+            policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Place the wrapped model and loss function modules on that device.
+        self._model.set_device(device)
+        self._loss_fn.set_device(device)
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index 5a6ee87b5a5e1da03e9e2998d9e3338639e6e394..6b2a19e2d05ea79e8e83c4206153c6d614e3e735 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -19,25 +19,47 @@
 
 from typing import Any, Callable, Dict, Optional, Set, Tuple, Type
 
-import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
 from declearn.model.api import Vector, register_vector_type
 from declearn.model.sklearn import NumpyVector
+from declearn.model.torch.utils import select_device
+from declearn.utils import get_device_policy
 
 
 @register_vector_type(torch.Tensor)
 class TorchVector(Vector):
     """Vector subclass to store PyTorch tensors.
 
-    This Vector is designed to store a collection of named
-    PyTorch tensors, enabling computations that are either
-    applied to each and every coefficient, or imply two sets
-    of aligned coefficients (i.e. two TorchVector with
-    similar specifications).
+    This Vector is designed to store a collection of named PyTorch
+    tensors, enabling computations that are either applied to each
+    and every coefficient, or imply two sets of aligned coefficients
+    (i.e. two TorchVector with similar specifications).
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `TorchVector` can be operated with either:
+      - a scalar value
+      - a `NumpyVector` that has similar specifications
+      - a `TorchVector` that has similar specifications
+      => resulting in a `TorchVector` in each of these cases.
+    - The wrapped tensors may be placed on any device (CPU, GPU...)
+      and may not all be on the same device.
+    - The device-placement of the initial `TorchVector`'s data
+      is preserved by operations, including those with `NumpyVector`.
+    - When combining two `TorchVector` instances, the device-
+      placement of the left-most one is used; as a result, one ends
+      up with `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both
+      cases, a warning is emitted to prevent silent un-optimized
+      copies.
+    - When deserializing a `TorchVector` (either by directly using
+      `TorchVector.unpack` or loading one from a JSON dump), loaded
+      tensors are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement scheme than at dump time, but
+      one coherent with that of `TorchModel` computations.
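+
+    Example
+    -------
+    A minimal sketch of supported operations:
+
+    >>> vec = TorchVector({"a": torch.ones(2)})
+    >>> sum_vec = vec + vec    # TorchVector, on `vec`'s device
+    >>> half_vec = vec * 0.5   # operating with a scalar also works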
     """
 
     @property
@@ -73,12 +95,20 @@ class TorchVector(Vector):
         other: Any,
         func: Callable[[Any, Any], Any],
     ) -> Self:
+        # Convert 'other' NumpyVector to a (CPU-backed) TorchVector.
         if isinstance(other, NumpyVector):
             # false-positive; pylint: disable=no-member
             coefs = {
                 key: torch.from_numpy(val) for key, val in other.coefs.items()
             }
             other = TorchVector(coefs)
+        # Ensure 'other' TorchVector shares this vector's device placement.
+        if isinstance(other, TorchVector):
+            coefs = {
+                key: val.to(self.coefs[key].device)
+                for key, val in other.coefs.items()
+            }
+            other = TorchVector(coefs)
         return super()._apply_operation(other, func)
 
     def dtypes(
@@ -95,15 +125,20 @@ class TorchVector(Vector):
     def pack(
         self,
     ) -> Dict[str, Any]:
-        return {key: tns.numpy() for key, tns in self.coefs.items()}
+        return {key: tns.cpu().numpy() for key, tns in self.coefs.items()}
 
     @classmethod
     def unpack(
         cls,
         data: Dict[str, Any],
     ) -> Self:
-        # false-positive; pylint: disable=no-member
-        coefs = {key: torch.from_numpy(dat) for key, dat in data.items()}
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        coefs = {
+            # false-positive on `torch.from_numpy`; pylint: disable=no-member
+            key: torch.from_numpy(dat).to(device)
+            for key, dat in data.items()
+        }
         return cls(coefs)
 
     def __eq__(
@@ -115,8 +150,9 @@ class TorchVector(Vector):
             valid = self.coefs.keys() == other.coefs.keys()
         if valid:
             valid = all(
-                np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy())
-                for k in self.coefs
+                # false-positive on 'torch.equal'; pylint: disable=no-member
+                torch.equal(tns, other.coefs[key].to(tns.device))
+                for key, tns in self.coefs.items()
             )
         return valid
 
diff --git a/declearn/model/torch/utils/__init__.py b/declearn/model/torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..225118fa3cc035d9bac36542a59a389906587a6e
--- /dev/null
+++ b/declearn/model/torch/utils/__init__.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for torch backend support code.
+
+GPU/CPU backing device management utils:
+* AutoDeviceModule:
+    Wrapper for a `torch.nn.Module`, automating device-management.
+* select_device:
+    Select a backing device to use based on inputs and availability.
+"""
+
+from ._gpu import AutoDeviceModule, select_device
diff --git a/declearn/model/torch/utils/_gpu.py b/declearn/model/torch/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ed28dc340721d5e92d0e65a851592f4c18fffc
--- /dev/null
+++ b/declearn/model/torch/utils/_gpu.py
@@ -0,0 +1,140 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for GPU support and device management in torch."""
+
+import warnings
+from typing import Any, Optional
+
+import torch
+
+
+__all__ = [
+    "AutoDeviceModule",
+    "select_device",
+]
+
+
+def select_device(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> torch.device:  # pylint: disable=no-member
+    """Select a backing device to use based on inputs and availability.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to select a GPU device rather than the CPU one.
+    idx: int or None, default=None
+        Optional pre-selected GPU device index. Only used when `gpu=True`.
+        If `idx is None` or exceeds the number of available GPU devices,
+        use `torch.cuda.current_device()`.
+
+    Warns
+    -----
+    UserWarning:
+        If `gpu=True` but no GPU is available.
+        If `idx` exceeds the number of available GPU devices.
+
+    Returns
+    -------
+    device: torch.device
+        Selected torch device, with type "cpu" or "cuda".
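+
+    Example
+    -------
+    >>> select_device(gpu=False)
+    device(type='cpu')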
+    """
+    # Case when instructed to use the CPU device.
+    if not gpu:
+        return torch.device("cpu")  # pylint: disable=no-member
+    # Case when no GPU is available: warn and use the CPU instead.
+    if gpu and not torch.cuda.is_available():
+        warnings.warn(
+            "Cannot use a GPU device: either CUDA is unavailable "
+            "or no GPU is visible to torch."
+        )
+        return torch.device("cpu")  # pylint: disable=no-member
+    # Case when the desired GPU is invalid: select another one.
+    if (idx or 0) >= torch.cuda.device_count():
+        warnings.warn(
+            f"Cannot use GPU device n°{idx}: index is out-of-range.\n"
+            f"Using GPU device n°{torch.cuda.current_device()} instead."
+        )
+        idx = None
+    # Return the selected or auto-selected GPU device index.
+    if idx is None:
+        idx = torch.cuda.current_device()
+    return torch.device("cuda", index=idx)  # pylint: disable=no-member
+
+
+class AutoDeviceModule(torch.nn.Module):
+    """Wrapper for a `torch.nn.Module`, automating device-management.
+
+    This `torch.nn.Module` subclass enables wrapping another one, and
+    provides:
+    * a `device` attribute (and instantiation parameter) indicating
+      where the wrapped module is placed
+    * automatic placement of input tensors on that device as part of
+      `forward` calls to the module
+    * a `set_device` method to change the device and move the wrapped
+      module to it
+
+    This aims at internalizing device-management boilerplate code.
+    The wrapped module is assigned to the `module` attribute and thus
+    can be accessed directly.
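+
+    Example
+    -------
+    A minimal sketch (assuming a CPU-only setup):
+
+    >>> module = AutoDeviceModule(torch.nn.Linear(4, 2), torch.device("cpu"))
+    >>> output = module(torch.ones(1, 4))  # inputs are moved to `module.device`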
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        device: torch.device,  # pylint: disable=no-member
+    ) -> None:
+        """Wrap a torch Module into an AutoDeviceModule.
+
+        Parameters
+        ----------
+        module: torch.nn.Module
+            Torch module that needs wrapping.
+        device: torch.device
+            Torch device where to place the wrapped module and computations.
+        """
+        super().__init__()
+        self.device = device
+        self.module = module.to(self.device)
+
+    def forward(self, *inputs: Any) -> torch.Tensor:
+        """Run the forward computation, automating device-placement of inputs.
+
+        Please refer to `self.module.forward` for details on the wrapped
+        module's forward specifications.
+        """
+        inputs = tuple(
+            x.to(self.device) if isinstance(x, torch.Tensor) else x
+            for x in inputs
+        )
+        return self.module(*inputs)
+
+    def set_device(
+        self,
+        device: torch.device,  # pylint: disable=no-member
+    ) -> None:
+        """Move the wrapped module to a pre-selected torch device.
+
+        Parameters
+        ----------
+        device: torch.device
+           Torch device where to place the wrapped module and computations.
+        """
+        self.device = device
+        self.module.to(device)
diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py
index d789cec4933795a0364b3083164ff449b9f79c57..27f821674acec621db1fbf14d569d0c9505b55ea 100644
--- a/declearn/test_utils/_vectors.py
+++ b/declearn/test_utils/_vectors.py
@@ -87,7 +87,8 @@ class GradientsTestCase:
             return array
         if self.framework == "tensorflow":
             tensorflow = importlib.import_module("tensorflow")
-            return tensorflow.convert_to_tensor(array)
+            with tensorflow.device("CPU"):
+                return tensorflow.convert_to_tensor(array)
         if self.framework == "torch":
             torch = importlib.import_module("torch")
             return torch.from_numpy(array)
diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py
index 00bf9fc14932fef61b6eaa4ab9b8aefeb404374c..bee5e57cbbe3f30f9b6acd2753ad9096c3c25ab8 100644
--- a/declearn/utils/__init__.py
+++ b/declearn/utils/__init__.py
@@ -67,6 +67,17 @@ And examples of pre-registered (de)serialization functions:
 * (deserialize_numpy, serialize_numpy):
     Pair of functions to (un)pack a numpy ndarray as JSON-serializable data.
 
+Device-policy utils
+-------------------
+Utils to access or update parameters defining a global device-selection policy.
+
+* DevicePolicy:
+    Dataclass to store parameters defining a device-selection policy.
+* get_device_policy:
+    Access a copy of the current global device policy.
+* set_device_policy:
+    Update the current global device policy.
+
 Miscellaneous
 -------------
 
@@ -84,6 +95,11 @@ from ._dataclass import (
     dataclass_from_func,
     dataclass_from_init,
 )
+from ._device_policy import (
+    DevicePolicy,
+    get_device_policy,
+    set_device_policy,
+)
 from ._json import (
     add_json_support,
     json_dump,
diff --git a/declearn/utils/_device_policy.py b/declearn/utils/_device_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26ea9653d4018d88c23e0020fb9aee38b968a72
--- /dev/null
+++ b/declearn/utils/_device_policy.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to define a computation device policy.
+
+This private submodule defines:
+* A dataclass defining a standard to hold a device-selection policy.
+* A private global variable holding the current package-wise device policy.
+* A public pair of functions acting as a getter and a setter for that variable.
+"""
+
+import dataclasses
+from typing import Optional
+
+
+__all__ = [
+    "DevicePolicy",
+    "get_device_policy",
+    "set_device_policy",
+]
+
+
+@dataclasses.dataclass
+class DevicePolicy:
+    """Dataclass to store parameters defining a device-selection policy.
+
+    This class merely defines a shared language of keyword-arguments to
+    define whether to back computations on CPU or on a GPU device.
+
+    It is meant to be instantiated as a global variable that holds that
+    information, and can be accessed by framework-specific backend code
+    so as to take the required steps towards implementing that policy.
+
+    To access or update the current global DevicePolicy, please use the
+    getter and setter functions: `declearn.utils.get_device_policy` and
+    `declearn.utils.set_device_policy`.
+
+    Attributes
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None
+        Optional index of the GPU device to use.
+        If None, select one arbitrarily.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
+    """
+
+    gpu: bool
+    idx: Optional[int]
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.gpu, bool):
+            raise TypeError(
+                f"DevicePolicy 'gpu' should be a bool, not '{type(self.gpu)}'."
+            )
+        if not (self.idx is None or isinstance(self.idx, int)):
+            raise TypeError(
+                "DevicePolicy 'idx' should be None or an int, not "
+                f"'{type(self.idx)}'."
+            )
+
+
+DEVICE_POLICY = DevicePolicy(gpu=True, idx=None)
+
+
+def get_device_policy() -> DevicePolicy:
+    """Return a copy of the current global device policy.
+
+    This function is meant to be used:
+    - By end-users who wish to check the current device policy.
+    - By the backend code of framework-specific objects so as to
+      take the required steps towards implementing that policy.
+
+    To update the current policy, use `declearn.utils.set_device_policy`.
+
+    Returns
+    -------
+    policy: DevicePolicy
+        DevicePolicy dataclass instance, wrapping parameters that specify
+        the device policy to be enforced by Model and Vector to properly
+        place data and computations.
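+
+    Examples
+    --------
+    Minimal illustrative sketch; the output shown assumes the global
+    policy was left to its default value:
+
+    >>> get_device_policy()
+    DevicePolicy(gpu=True, idx=None)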
+    """
+    return DevicePolicy(**dataclasses.asdict(DEVICE_POLICY))
+
+
+def set_device_policy(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> None:
+    """Update the current global device policy.
+
+    To access the current policy, use `declearn.utils.get_device_policy`.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None, default=None
+        Optional index of the GPU device to use.
+        If None, select one arbitrarily.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
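+
+    Examples
+    --------
+    Minimal illustrative sketch; note that this updates the global policy:
+
+    >>> set_device_policy(gpu=False)  # force CPU use
+    >>> get_device_policy()
+    DevicePolicy(gpu=False, idx=None)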
+    """
+    # Use a global statement to provide a proper setter for the private variable.
+    global DEVICE_POLICY  # pylint: disable=global-statement
+    DEVICE_POLICY = DevicePolicy(gpu, idx)
diff --git a/test/model/model_testing.py b/test/model/model_testing.py
index 3d632814f678f506a81a25dc7f384141fcbf55df..9936788969cecc3eef00913bcc2f7e0659a8eee8 100644
--- a/test/model/model_testing.py
+++ b/test/model/model_testing.py
@@ -18,7 +18,7 @@
 """Shared testing code for TensorFlow and Torch models' unit tests."""
 
 import json
-from typing import Any, List, Protocol, Tuple, Type, Union
+from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
@@ -27,10 +27,13 @@ from declearn.typing import Batch
 from declearn.utils import json_pack, json_unpack
 
 
-class ModelTestCase(Protocol):
+VectorT = TypeVar("VectorT", bound=Vector)
+
+
+class ModelTestCase(Protocol, Generic[VectorT]):
     """TestCase fixture-provider protocol."""
 
-    vector_cls: Type[Vector]
+    vector_cls: Type[VectorT]
     tensor_cls: Union[Type[Any], Tuple[Type[Any], ...]]
 
     @staticmethod
@@ -51,6 +54,12 @@ class ModelTestCase(Protocol):
     ) -> Model:
         """Suited toy binary-classification model."""
 
+    def assert_correct_device(
+        self,
+        vector: VectorT,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+
 
 class ModelTestSuite:
     """Unit tests for a declearn Model."""
@@ -64,6 +73,7 @@ class ModelTestSuite:
         config = json.dumps(model.get_config())
         other = model.from_config(json.loads(config))
         assert model.get_config() == other.get_config()
+        assert model.device_policy == other.device_policy
 
     def test_get_set_weights(
         self,
@@ -75,7 +85,12 @@ class ModelTestSuite:
         assert isinstance(w_srt, test_case.vector_cls)
         w_end = w_srt + 1.0
         model.set_weights(w_end)
-        assert model.get_weights() == w_end
+        w_upd = model.get_weights()
+        assert w_upd == w_end
+        # Check that weight tensors are properly placed.
+        test_case.assert_correct_device(w_srt)
+        test_case.assert_correct_device(w_end)
+        test_case.assert_correct_device(w_upd)
 
     def test_compute_batch_gradients(
         self,
@@ -113,7 +128,13 @@ class ModelTestSuite:
         np_grads = model.compute_batch_gradients(np_batch)  # type: ignore
         assert isinstance(np_grads, test_case.vector_cls)
         my_grads = model.compute_batch_gradients(my_batch)
-        assert my_grads == np_grads
+        # Allow for a numerical imprecision of 10^-8.
+        diff = my_grads - np_grads
+        max_err = max(
+            np.abs(test_case.to_numpy(weight)).max()
+            for weight in diff.coefs.values()
+        )
+        assert max_err < 1e-8
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -137,6 +158,9 @@ class ModelTestSuite:
             for k in grads_a.coefs
         )
         assert grads_a != grads_b
+        # Check that gradients are properly placed.
+        test_case.assert_correct_device(grads_a)
+        test_case.assert_correct_device(grads_b)
 
     def test_apply_updates(
         self,
@@ -157,9 +181,14 @@ class ModelTestSuite:
         # NOTE: if the model had frozen weights, this test would xfail.
         w_end = model.get_weights()
         assert w_end != w_srt
-        updt = [test_case.to_numpy(val) for val in grads.coefs.values()]
-        diff = list((w_end - w_srt).coefs.values())
-        assert all(np.abs(a - b).max() < 1e-6 for a, b in zip(diff, updt))
+        diff = (w_end - w_srt) - grads
+        assert all(
+            np.abs(test_case.to_numpy(weight)).max() < 1e-6
+            for weight in diff.coefs.values()
+        )
+        # Check that gradients and updated weights are properly placed.
+        test_case.assert_correct_device(grads)
+        test_case.assert_correct_device(w_end)
 
     def test_serialize_gradients(
         self,
diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py
index b214f28ac1c0fb595365dec187ca8ac94d52ed71..e53febae8f98d3fd43fbabc61695e9ae366c054b 100644
--- a/test/model/test_sksgd.py
+++ b/test/model/test_sksgd.py
@@ -106,6 +106,12 @@ class SklearnSGDTestCase(ModelTestCase):
         model.initialize(data_info)
         return model
 
+    def assert_correct_device(
+        self,
+        vector: NumpyVector,
+    ) -> None:
+        # NumpyVector coefficients are plain numpy arrays, hence always
+        # CPU-backed: there is no device placement to check.
+        pass
+
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index aafd5d5d567ff37baba544ab8ff34904a529674d..ba1193f6d19e2e15da5dc98a27291bf5fc0722a3 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -33,6 +33,7 @@ except ModuleNotFoundError:
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 from declearn.typing import Batch
+from declearn.utils import set_device_policy
 
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
@@ -67,11 +68,16 @@ class TensorflowTestCase(ModelTestCase):
     def __init__(
         self,
         kind: Literal["MLP", "RNN", "CNN"],
+        device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
         if kind not in ("MLP", "RNN", "CNN"):
             raise ValueError(f"Invalid keras test architecture: '{kind}'.")
+        if device not in ("CPU", "GPU"):
+            raise ValueError(f"Invalid device choice for test: '{device}'.")
         self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "GPU"), idx=0)
 
     @staticmethod
     def to_numpy(
@@ -79,7 +85,7 @@ class TensorflowTestCase(ModelTestCase):
     ) -> np.ndarray:
         """Convert an input tensor to a numpy array."""
         assert isinstance(tensor, tf.Tensor)
-        return tensor.numpy()  # type: ignore
+        return tensor.numpy()
 
     @property
     def dataset(
@@ -132,16 +138,47 @@ class TensorflowTestCase(ModelTestCase):
         tfmod.build(shape)  # as model is built, no data_info is required
         return TensorflowModel(tfmod, loss="binary_crossentropy", metrics=None)
 
+    def assert_correct_device(
+        self,
+        vector: TensorflowVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        name = f"{self.device}:0"
+        assert all(
+            tensor.device.endswith(name) for tensor in vector.coefs.values()
+        )
+
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
-    kind: Literal["MLP", "RNN", "CNN"]
+    kind: Literal["MLP", "RNN", "CNN"],
+    device: Literal["CPU", "GPU"],
 ) -> TensorflowTestCase:
     """Fixture to access a TensorflowTestCase."""
-    return TensorflowTestCase(kind)
+    return TensorflowTestCase(kind, device)
+
 
+DEVICES = ["CPU"]
+if tf.config.list_logical_devices("GPU"):
+    DEVICES.append("GPU")
 
+
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("kind", ["MLP", "RNN", "CNN"])
 @pytest.mark.filterwarnings("ignore:.*randrange.*:DeprecationWarning")
 class TestTensorflowModel(ModelTestSuite):
     """Unit tests for declearn.model.tensorflow.TensorflowModel."""
+
+    def test_proper_model_placement(
+        self,
+        test_case: TensorflowTestCase,
+    ) -> None:
+        """Check that at instantiation, model weights are properly placed."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "GPU")
+        assert policy.idx == 0
+        tfmod = getattr(model, "_model")
+        device = f"{test_case.device}:0"
+        for var in tfmod.weights:
+            assert var.device.endswith(device)
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index 18b6418a632182262e040ecada563e3cc958ebee..5a6729d59bd53e0e69273ddf5ad936776ddfc78f 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -17,6 +17,7 @@
 
 """Unit tests for TorchModel."""
 
+import json
 import sys
 from typing import Any, List, Literal, Tuple
 
@@ -30,6 +31,7 @@ except ModuleNotFoundError:
 
 from declearn.model.torch import TorchModel, TorchVector
 from declearn.typing import Batch
+from declearn.utils import set_device_policy
 
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
@@ -84,11 +86,14 @@ class TorchTestCase(ModelTestCase):
     def __init__(
         self,
         kind: Literal["MLP", "RNN", "CNN"],
+        device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
         if kind not in ("MLP", "RNN", "CNN"):
             raise ValueError(f"Invalid torch test architecture: '{kind}'.")
         self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "GPU"), idx=0)
 
     @staticmethod
     def to_numpy(
@@ -96,7 +101,7 @@ class TorchTestCase(ModelTestCase):
     ) -> np.ndarray:
         """Convert an input tensor to a numpy array."""
         assert isinstance(tensor, torch.Tensor)
-        return tensor.numpy()  # type: ignore
+        return tensor.cpu().numpy()
 
     @property
     def dataset(
@@ -133,7 +138,7 @@ class TorchTestCase(ModelTestCase):
         elif self.kind == "RNN":
             stack = [
                 torch.nn.Embedding(100, 32),
-                torch.nn.LSTM(32, 16, batch_first=True),  # type: ignore
+                torch.nn.LSTM(32, 16, batch_first=True),
                 ExtractLSTMFinalOutput(),
                 torch.nn.Tanh(),
                 torch.nn.Linear(16, 1),
@@ -154,13 +159,32 @@ class TorchTestCase(ModelTestCase):
         nnmod = torch.nn.Sequential(*stack)
         return TorchModel(nnmod, loss=torch.nn.BCELoss())
 
+    def assert_correct_device(
+        self,
+        vector: TorchVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        dev_type = "cuda" if self.device == "GPU" else "cpu"
+        assert all(
+            tensor.device.type == dev_type for tensor in vector.coefs.values()
+        )
+
 
 @pytest.fixture(name="test_case")
-def fixture_test_case(kind: Literal["MLP", "RNN", "CNN"]) -> TorchTestCase:
+def fixture_test_case(
+    kind: Literal["MLP", "RNN", "CNN"],
+    device: Literal["CPU", "GPU"],
+) -> TorchTestCase:
     """Fixture to access a TorchTestCase."""
-    return TorchTestCase(kind)
+    return TorchTestCase(kind, device)
+
 
+DEVICES = ["CPU"]
+if torch.cuda.device_count():
+    DEVICES.append("GPU")
 
+
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("kind", ["MLP", "RNN", "CNN"])
 class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
@@ -175,12 +199,33 @@ class TestTorchModel(ModelTestSuite):
             #       due to the (de)serialization of a custom nn.Module
             #       the expected model behaviour is, however, correct
             try:
-                super().test_serialization(test_case)
+                self._test_serialization(test_case)
             except AssertionError:
                 pytest.skip(
                     "skipping failed test due to custom nn.Module pickling"
                 )
-        super().test_serialization(test_case)
+        self._test_serialization(test_case)
+
+    def _test_serialization(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        """Check that the model can be JSON-(de)serialized properly.
+
+        This method replaces the parent `test_serialization` one.
+        """
+        # Same setup as in parent test: a model and a config-based other.
+        model = test_case.model
+        config = json.dumps(model.get_config())
+        other = model.from_config(json.loads(config))
+        # Verify that both models have the same device policy.
+        assert model.device_policy == other.device_policy
+        # Verify that both models have a similar structure of modules.
+        mod_a = list(getattr(model, "_model").modules())
+        mod_b = list(getattr(other, "_model").modules())
+        assert len(mod_a) == len(mod_b)
+        assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b))
+        assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b))
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -195,3 +240,17 @@ class TestTorchModel(ModelTestSuite):
                 )
         else:
             super().test_compute_batch_gradients_clipped(test_case)
+
+    def test_proper_model_placement(
+        self,
+        test_case: TorchTestCase,
+    ) -> None:
+        """Check that at instantiation, model weights are properly placed."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "GPU")
+        assert (policy.idx == 0) if policy.gpu else (policy.idx is None)
+        ptmod = getattr(model, "_model").module
+        device_type = "cpu" if test_case.device == "CPU" else "cuda"
+        for param in ptmod.parameters():
+            assert param.device.type == device_type
diff --git a/test/model/test_vector.py b/test/model/test_vector.py
index 6374ad409370145a4f6b2ca4b8a953f3bf9f777e..ca5592a76d25067569903d78b652f202b70fd783 100644
--- a/test/model/test_vector.py
+++ b/test/model/test_vector.py
@@ -32,7 +32,10 @@ from declearn.test_utils import (
     GradientsTestCase,
     list_available_frameworks,
 )
-from declearn.utils import json_pack, json_unpack
+from declearn.utils import json_pack, json_unpack, set_device_policy
+
+
+set_device_policy(gpu=False)  # run Vector unit tests on CPU only
 
 
 @pytest.fixture(name="framework", params=list_available_frameworks())
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 0bbde31ad6746bd94b2cc631c4d2b0a991645682..ffd1e3132ffb64691b04ebdd03cdad4bd5d4e355 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -44,7 +44,7 @@ from declearn.test_utils import (
     GradientsTestCase,
     assert_json_serializable_dict,
 )
-from declearn.utils import access_types_mapping
+from declearn.utils import access_types_mapping, set_device_policy
 
 # relative import; pylint: disable=wrong-import-order, wrong-import-position
 # fmt: off
@@ -56,6 +56,8 @@ sys.path.pop()
 
 OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
 
+set_device_policy(gpu=False)  # run all OptiModule tests on CPU
+
 
 @pytest.mark.parametrize(
     "cls", OPTIMODULE_SUBCLASSES.values(), ids=OPTIMODULE_SUBCLASSES.keys()