diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 71f1f27cb84deb0d43b0a4f598e2b92730729c23..5af099a99145ddc9cffe81ba447982e146927b43 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -7,7 +7,7 @@ before_script:
   - source venv/bin/activate
 
 # Run the test suite using tox.
-# This job is called when commits are pushed to the main branch,
+# This job is called when commits are pushed to the main or develop branch,
 # save for merge-resulting commits (that are expected to come from MRs).
 test:
   script:
@@ -15,14 +15,16 @@ test:
     - tox -e py38
   rules:
     - if: ($CI_COMMIT_BRANCH == "main") && ($CI_PIPELINE_SOURCE == "push")
-          && ($CI_COMMIT_TITLE !~ /^Merge branch '.*' into main/)
+          && ($CI_COMMIT_TITLE !~ /^Merge branch '.*' into 'main'/)
+    - if: ($CI_COMMIT_BRANCH == "develop") && ($CI_PIPELINE_SOURCE == "push")
+          && ($CI_COMMIT_TITLE !~ /^Merge branch '.*' into 'develop'/)
   tags:
     - ci.inria.fr
     - small
 
 # Run the test suite using tox, with --fulltest option.
 # This job is called on creation and pushes to non-Draft MRs.
-# It may also be launched manually upon pushes to Draft MRs or main branch.
+# It may also be launched manually upon pushes to Draft MRs, main or develop.
 test-full:
   script:
     - pip install -U tox
@@ -33,7 +35,8 @@ test-full:
     - if: ($CI_PIPELINE_SOURCE == "merge_request_event")
           && ($CI_MERGE_REQUEST_TITLE =~ /^Draft:.*/)
       when: manual
-    - if: ($CI_COMMIT_BRANCH == "main") && ($CI_PIPELINE_SOURCE == "push")
+    - if: (($CI_COMMIT_BRANCH == "main") || ($CI_COMMIT_BRANCH == "develop"))
+          && ($CI_PIPELINE_SOURCE == "push")
       when: manual
   tags:
     - ci.inria.fr
diff --git a/README.md b/README.md
index 1866115e93bfcbec2f7daf2f25b9593193c68c05..3ab040e8d80bf7b0afe2d1ede39701bdba70ea97 100644
--- a/README.md
+++ b/README.md
@@ -227,6 +227,38 @@ client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs"
 client.run()
 ```
 
+### Support for GPU acceleration
+
+TL;DR: GPU acceleration is natively available in `declearn` for model 
+frameworks that support it, with one line of code and without changing 
+your original model.
+
+Details:
+
+Most machine learning frameworks, including Tensorflow and Torch, enable 
+accelerating computations by using computational devices other than CPU. 
+`declearn` interfaces supported frameworks to be able to set a device policy 
+in a single line of code, across frameworks.
+
+`declearn` internalizes the framework-specific code adaptations to place the
+data, model weights and computations on such a device. `declearn` provides
+a simple API to define a global device policy. This enables using a
+single GPU to accelerate computations, or forcing the use of a CPU. 
+
+By default, the policy is set to use the first available GPU, and otherwise 
+use the CPU, with a warning that can safely be ignored.
+
+Setting the device policy to be used can be done in local scripts, either as a 
+client or as a server. Device policy is local and is not synchronized between 
+federated learning participants.
+
+Here are some examples of the one-liner used:
+```python
+declearn.utils.set_device_policy(gpu=False)  # disable GPU use
+declearn.utils.set_device_policy(gpu=True)  # use any available GPU
+declearn.utils.set_device_policy(gpu=True, idx=2)  # specifically use GPU n°2
+```
+
 ### Note on dependency sharing
 
 One important issue however that is not handled by declearn itself is that
@@ -701,7 +733,24 @@ forward framework evolutions and API revisions.
 
 To contribute directly to the code (beyond posting issues on gitlab), please
 create a dedicated branch, and submit a **Merge Request** once you want your
-work reviewed and further processed to end up integrated into the main branch.
+work reviewed and further processed to end up integrated into the package.
+
+The **git branching strategy** is the following:
+
+- 'main' matches the latest release's X.Y version, but may hold unreleased
+  patch changes; i.e. it can be seen as version X.Y.(Z+1)-beta
+- 'develop' holds finalized changes that should be made part of the next
+  minor version release; i.e. it can be seen as version X.(Y+1).0-beta
+- when necessary, intermediate release branches may be set up to cherry-pick
+  changes from 'develop' to be included in a given minor version release
+- 'main', 'develop' and any intermediate release branch are expected to be
+  stable at all times
+- feature branches should be created at will to develop features, enhancements,
+  or even hotfixes that will later be merged into 'develop' and eventually into
+  'main'.
+- it is legit to write up poc branches, as well as to split the development of
+  a feature into multiple branches that will incrementally be merged into an
+  intermediate feature branch that will eventually be merged into 'develop'
 
 The **coding rules** are fairly simple:
 
@@ -723,6 +772,10 @@ The **coding rules** are fairly simple:
 - reformat your code using [black](https://github.com/psf/black); do use
   (sparingly) "fmt: off/on" comments when you think it relevant
   (see dedicated sub-section [below](#running-black-to-format-the-code))
+- abide by [semver](https://semver.org/) when implementing new features or
+  changing the existing APIs; try making changes non-breaking, document and
+  warn about deprecations or behavior changes, or make a point for API-breaking
+  changes, which we are happy to consider but might take time to be released
 
 ### Unit tests and code analysis
 
diff --git a/declearn/metrics/__init__.py b/declearn/metrics/__init__.py
index d4e9cf5ddbf1c071795964bdc20bee1ff18197c3..0051de291a9ad01d3e4f5abb354991cb2cdb1ee7 100644
--- a/declearn/metrics/__init__.py
+++ b/declearn/metrics/__init__.py
@@ -43,13 +43,19 @@ Classification metrics:
     Identifier name: "multi-classif".
 * BinaryRocAuc:
     Receiver Operator Curve and its Area Under the Curve for binary classif.
-    Identified name: "binary-roc"
+    Identifier name: "binary-roc"
 
 Regression metrics:
 * MeanAbsoluteError:
     Mean absolute error, averaged across all samples (and channels).
+    Identifier name: "mae"
 * MeanSquaredError:
     Mean squared error, averaged across all samples (and channels).
+    Identifier name: "mse"
+* RSquared:
+    R^2 (R-Squared, coefficient of determination) regression metric.
+    Identifier name: "r2"
+
 """
 
 from ._api import Metric
@@ -59,4 +65,5 @@ from ._classif import (
 )
 from ._mean import MeanMetric, MeanAbsoluteError, MeanSquaredError
 from ._roc_auc import BinaryRocAUC
+from ._rsquared import RSquared
 from ._wrapper import MetricInputType, MetricSet
diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py
index d19cc2f210c89f9ca7d8d7322b441445e297c768..16d74c78559edaa2cbe28c3fdcbbd05303f96fe9 100644
--- a/declearn/metrics/_api.py
+++ b/declearn/metrics/_api.py
@@ -17,6 +17,7 @@
 
 """Iterative and federative evaluation metrics base class."""
 
+import warnings
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy
 from typing import Any, ClassVar, Dict, Optional, Union
@@ -181,23 +182,6 @@ class Metric(metaclass=ABCMeta):
             Optional sample weights to take into account in scores.
         """
 
-    @staticmethod
-    def normalize_weights(s_wght: np.ndarray) -> np.ndarray:
-        """Utility method to ensure weights sum to one.
-
-        Note that this method may or may not be used depending on
-        the actual `Metric` considered, and is merely provided as
-        a utility to metric developers.
-        """
-        if s_wght.sum():
-            s_wght /= s_wght.sum()
-        else:
-            raise ValueError(
-                "Weights provided sum to zero, please provide only "
-                "positive weights with at least one non-zero weight."
-            )
-        return s_wght
-
     def reset(
         self,
     ) -> None:
@@ -315,3 +299,65 @@ class Metric(metaclass=ABCMeta):
                 f"Failed to retrieve Metric subclass from name '{name}'."
             ) from exc
         return cls.from_config(config or {})
+
+    @staticmethod
+    def _prepare_sample_weights(
+        s_wght: Optional[np.ndarray],
+        n_samples: int,
+    ) -> np.ndarray:
+        """Flatten or generate sample weights and validate their shape.
+
+        This method is a shared util that may or may not be used as part
+        of concrete Metric classes' backend depending on their formula.
+
+        Parameters
+        ----------
+        s_wght: np.ndarray or None
+            1-d (or squeezable) array of sample-wise non-negative scalar
+            weights. If None, an array of ones will be generated.
+        n_samples: int
+            Expected length of the sample weights.
+
+        Returns
+        -------
+        s_wght: np.ndarray
+            Input (opt. squeezed) `s_wght`, or `np.ones(n_samples)`
+            if input was None.
+
+        Raises
+        ------
+        ValueError:
+            If the input array has improper shape or negative values.
+        """
+        if s_wght is None:
+            return np.ones(shape=(n_samples,))
+        s_wght = s_wght.squeeze()
+        if s_wght.shape != (n_samples,) or np.any(s_wght < 0):
+            raise ValueError(
+                "Improper shape for 's_wght': should be a 1-d array "
+                "of sample-wise positive scalar weights."
+            )
+        return s_wght
+
+    @staticmethod
+    def normalize_weights(s_wght: np.ndarray) -> np.ndarray:
+        """Utility method to ensure weights sum to one.
+
+        Note that this method may or may not be used depending on
+        the actual `Metric` considered, and is merely provided as
+        a utility to metric developers.
+        """
+        warn = DeprecationWarning(
+            "'Metric.normalize_weights' is unfit for the iterative "
+            "nature of the metric-computation process. It will be "
+            "removed from the Metric API in declearn v3.0."
+        )
+        warnings.warn(warn)
+        if s_wght.sum():
+            s_wght /= s_wght.sum()
+        else:
+            raise ValueError(
+                "Weights provided sum to zero, please provide only "
+                "positive weights with at least one non-zero weight."
+            )
+        return s_wght
diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py
index 7ac765a12c2309f18dfce7dcc9c1142d9310a619..f92f01d8fd553b6b2a483c7dc4e2ff299070e2d8 100644
--- a/declearn/metrics/_mean.py
+++ b/declearn/metrics/_mean.py
@@ -101,12 +101,7 @@ class MeanMetric(Metric, register=False, metaclass=ABCMeta):
             self._states["current"] += scores.sum()
             self._states["divisor"] += len(y_pred)
         else:
-            s_wght = s_wght.squeeze()
-            if s_wght.shape != (len(y_pred),):
-                raise ValueError(
-                    "Improper shape for 's_wght': should be a 1-d array "
-                    "of sample-wise scalar weights."
-                )
+            s_wght = self._prepare_sample_weights(s_wght, len(y_pred))
             self._states["current"] += (s_wght * scores).sum()
             self._states["divisor"] += np.sum(s_wght)
 
diff --git a/declearn/metrics/_roc_auc.py b/declearn/metrics/_roc_auc.py
index a8a97dcc0a277081ba9fe92f098ba1dca92a3f80..4ed37a853cc57f5e12b99ed0ab8b5785d6712c4b 100644
--- a/declearn/metrics/_roc_auc.py
+++ b/declearn/metrics/_roc_auc.py
@@ -53,7 +53,9 @@ class BinaryRocAUC(Metric):
 
     Note that this class supports aggregating states from another
     BinaryRocAUC instance with different hyper-parameters into it,
-    unless its
+    unless its `bound` parameter is set - in which case thresholds
+    are not authorized to be dynamically updated, either at samples
+    processing or states-aggregating steps.
     """
 
     name: ClassVar[str] = "binary-roc"
diff --git a/declearn/metrics/_rsquared.py b/declearn/metrics/_rsquared.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b99d834b562332716e8413c66647b5885728b58
--- /dev/null
+++ b/declearn/metrics/_rsquared.py
@@ -0,0 +1,130 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Iterative and federative R-Squared evaluation metric."""
+
+from typing import ClassVar, Dict, Optional, Union
+
+import numpy as np
+
+from declearn.metrics._api import Metric
+
+__all__ = [
+    "RSquared",
+]
+
+
+class RSquared(Metric):
+    """R^2 (R-Squared, coefficient of determination) regression metric.
+
+    This metric applies to a regression model, and computes the (opt.
+    weighted) R^2 score, also known as coefficient of determination.
+
+    Computed metric is the following:
+    * r2: float
+        R^2 score, or coefficient of determination, averaged across samples.
+        It is defined as the proportion of total sample variance explained
+        by the regression model:
+        * SSr = Sum((true - pred)^2)  # Residual sum of squares
+        * SSt = Sum((true - mean(true))^2)  # Total sum of squares
+        * R^2 = 1 - (SSr / SSt)
+
+    Notes:
+    - This metric expects 1d-arrays, or arrays that can be reduced to 1-d
+    - If the true variance is zero, we by convention return a perfect score
+      if the residual sum of squares is also zero, else return a score of 0.0.
+    - The R^2 score is not well-defined with less than two samples.
+
+    Implementation details:
+    - Since this metric is to be computed iteratively and with a single pass
+      over a batched dataset, we use the König-Huygens formula to decompose
+      the total sum of squares into a sum of terms that can be updated with
+      summation for each batch received (as opposed to using an estimate of
+      the mean of true values that would vary with each batch). This gives:
+        wSST = Sum(weight * (true - mean(true))^2)  # initial definition
+        wSST = Sum(weight * true^2) - (Sum(weight * true))^2 / Sum(weight)
+
+    LaTeX formulas (with weights):
+    - Canonical formula:
+        $$R^2(y, \\hat{y})= 1 - \\frac{
+            \\sum_{i=1}^n w_i \\left(y_i-\\hat{y}_i\\right)^2
+        }{
+            \\sum_{i=1}^n w_i \\left(y_i-\\bar{y}\\right)^2
+        }$$
+    - Decomposed weighted total sum of squares:
+        $$\\sum_{i=1}^n w_i \\left(y_i-\\bar{y}\\right)^2 =
+            \\sum_i w_i y_i^2
+            - \\frac{\\left(\\sum_i w_i y_i\\right)^2}{\\sum_i w_i}
+        $$
+    """
+
+    name: ClassVar[str] = "r2"
+
+    def _build_states(
+        self,
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        return {
+            "sum_of_squared_errors": 0.0,
+            "sum_of_squared_labels": 0.0,
+            "sum_of_labels": 0.0,
+            "sum_of_weights": 0.0,
+        }
+
+    def get_result(
+        self,
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        # Case when no samples were seen: return 0. by convention.
+        if self._states["sum_of_weights"] == 0:
+            return {self.name: 0.0}
+        # Compute the (weighted) total sum of squares.
+        ss_tot = (  # wSSt = sum(w * y^2) - (sum(w * y))^2 / sum(w)
+            self._states["sum_of_squared_labels"]
+            - self._states["sum_of_labels"] ** 2
+            / self._states["sum_of_weights"]
+        )
+        ss_res = self._states["sum_of_squared_errors"]
+        # Handle the edge case where SSt is null.
+        if ss_tot == 0:
+            return {self.name: 1.0 if ss_res == 0 else 0.0}
+        # Otherwise, compute and return the R-squared metric.
+        result = 1 - ss_res / ss_tot
+        return {self.name: float(result)}
+
+    def update(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        s_wght: Optional[np.ndarray] = None,
+    ) -> None:
+        # Verify sample weights' shape, or set up 1-valued ones.
+        s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred))
+        # Update the residual sum of squares. wSSr = sum(w * (y - p)^2)
+        ss_res = (s_wght * self._sum_to_1d(y_true - y_pred) ** 2).sum()
+        self._states["sum_of_squared_errors"] += ss_res
+        # Update states that compose the total sum of squares.
+        # wSSt = sum(w * y^2) - (sum(w * y))^2 / sum(w)
+        y_true = self._sum_to_1d(y_true)
+        self._states["sum_of_squared_labels"] += (s_wght * y_true**2).sum()
+        self._states["sum_of_labels"] += (s_wght * y_true).sum()
+        self._states["sum_of_weights"] += s_wght.sum()
+
+    @staticmethod
+    def _sum_to_1d(val: np.ndarray) -> np.ndarray:
+        "Utility method to reduce an array of any shape to a 1-d array"
+        while val.ndim > 1:
+            val = val.sum(axis=-1)
+        return val
diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index e934ce6d558327f8d2ba3a174992c303658af91b..3a9a87cbde3e944a57b47b0df3a14a28b7515985 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -25,7 +25,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api._vector import Vector
 from declearn.typing import Batch
-from declearn.utils import create_types_registry
+from declearn.utils import DevicePolicy, create_types_registry
 
 
 __all__ = [
@@ -43,6 +43,13 @@ class Model(metaclass=ABCMeta):
     writing algorithms and operations agnostic to the framework
     in which the underlying model is implemented (e.g. PyTorch,
     TensorFlow, Scikit-Learn...).
+
+    Device-placement (i.e. running computations on CPU or GPU)
+    is also handled as part of Model classes' backend, mapping
+    the generic `declearn.utils.DevicePolicy` parameters to any
+    required framework-specific instruction to adequately pick
+    the device to use and ensure the wrapped model, input data
+    and interfaced computations are placed there.
     """
 
     def __init__(
@@ -52,6 +59,13 @@ class Model(metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    @property
+    @abstractmethod
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        """Return the device-placement policy currently used by this model."""
+
     @property
     @abstractmethod
     def required_data_info(
@@ -263,3 +277,27 @@ class Model(metaclass=ABCMeta):
         s_loss: np.ndarray
             Sample-wise loss values, as a 1-d numpy array.
         """
+
+    @abstractmethod
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        """Update the device-placement policy of this model.
+
+        This method is designed to be called after a change in the global
+        device-placement policy (e.g. to disable using a GPU, or move to
+        a specific one), so as to place pre-existing Model instances and
+        avoid policy inconsistencies that might cause repeated memory or
+        runtime costs from moving data or weights around each time they
+        are used. You should otherwise not worry about a Model's device-
+        placement, as it is handled at instantiation based on the global
+        device policy (see `declearn.utils.set_device_policy`).
+
+        Parameters
+        ----------
+        policy: DevicePolicy or None, default=None
+            Optional DevicePolicy dataclass instance to be used.
+            If None, use the global device policy, accessed via
+            `declearn.utils.get_device_policy`.
+        """
diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py
index c55986773c830b5d0ddaecdee596879ee1efdfb5..fc6fdc78b15ddc5eb406b10d3ffef50d50ec2857 100644
--- a/declearn/model/api/_vector.py
+++ b/declearn/model/api/_vector.py
@@ -315,10 +315,7 @@ class Vector(metaclass=ABCMeta):
             return type(self)(coefs)
         # Case when the two vectors have incompatible types.
         if isinstance(other, Vector):
-            raise TypeError(
-                f"Cannot {func.__name__} {type(self).__name__} object with "
-                f"a vector of incompatible type {type(other).__name__}."
-            )
+            return NotImplemented
         # Case when operating with another object (e.g. a scalar).
         try:
             return type(self)(
diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 5903f7dd91ce4d7707d22511645899fd047d2b2e..65ac351d93904c3375f6d2244f8b271f3a067eb1 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -41,6 +41,19 @@ class NumpyVector(Vector):
     instances with similar coefficients specifications).
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `NumpyVector` can be operated with either a scalar value,
+      or another `NumpyVector` that has similar specifications
+      (same coefficient names, shapes and compatible dtypes).
+    - Some other `Vector` classes might be made compatible with
+      `NumpyVector`; in that case, operating with a `NumpyVector`
+      will always result in a vector of the other type. This is
+      notably the case with `TensorflowVector` and `TorchVector`.
+    - There is currently no support for GPU-acceleration with the
+      `NumpyVector` class, that only handles arrays and operations
+      placed on a CPU device.
     """
 
     @property
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 5485332f878ca52fde326a58fa678271972b55fe..47529be32ca8ff6e5718d504caa3106b9ffcb3fa 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -18,6 +18,7 @@
 """Model subclass to wrap scikit-learn SGD classifier and regressor models."""
 
 import typing
+import warnings
 from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -29,7 +30,7 @@ from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
 from declearn.model.sklearn._np_vec import NumpyVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, register_type
 
 
 __all__ = [
@@ -63,6 +64,13 @@ class SklearnSGDModel(Model):
     This `Model` subclass is designed to wrap a `SGDClassifier`
     or `SGDRegressor` instance (from `sklearn.linear_model`) to
     be learned federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * This Model may only run on CPU, and is unaffected by device-
+      management policies.
+    * Calling the `update_device_policy` method has no effect, and
+      raises a UserWarning if a GPU-targeting policy is passed to
+      it directly.
     """
 
     def __init__(
@@ -104,6 +112,12 @@ class SklearnSGDModel(Model):
             None
         )  # type: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        return DevicePolicy(gpu=False, idx=None)
+
     @property
     def required_data_info(
         self,
@@ -384,3 +398,10 @@ class SklearnSGDModel(Model):
         else:
             loss_fn = loss_1d
         return loss_fn
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        if policy is not None and policy.gpu:
+            warnings.warn("'SklearnSGDModel' only runs on a CPU backend.")
diff --git a/declearn/model/tensorflow/__init__.py b/declearn/model/tensorflow/__init__.py
index ff94c495b4647614cc5f9e9e44b0a9d636f3c3d3..a8db61f23c5e09cb1a25b2c2abdf70a9128d9972 100644
--- a/declearn/model/tensorflow/__init__.py
+++ b/declearn/model/tensorflow/__init__.py
@@ -24,7 +24,11 @@ through gradient descent.
 This module exposes:
 * TensorflowModel: Model subclass to wrap tensorflow.keras.Model objects
 * TensorflowVector: Vector subclass to wrap tensorflow.Tensor objects
+
+It also exposes the `utils` submodule, which mainly aims at
+providing tools used in the backend of the former objects.
 """
 
+from . import utils
 from ._vector import TensorflowVector
 from ._model import TensorflowModel
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index 3ca9843d2cd2c1a52f1451f0c3d6afacb36c3e56..732b8723a0ed1b436f46ee89a6d21e0814a68140 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -27,19 +27,44 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
-from declearn.model.tensorflow._utils import build_keras_loss
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.model.tensorflow._vector import TensorflowVector
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TensorflowModel",
+]
 
 
 @register_type(name="TensorflowModel", group="Model")
 class TensorflowModel(Model):
     """Model wrapper for TensorFlow Model instances.
 
-    This `Model` subclass is designed to wrap a `tf.keras.Model`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `tf.keras.Model` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default, tensorflow places data and operations on GPU whenever one
+      is available.
+    * Our `TensorflowModel` instead consults the device-placement policy (via
+      `declearn.utils.get_device_policy`), places the wrapped keras model's
+      weights there, and runs computations defined under public methods in
+      a `tensorflow.device` context, to enforce that policy.
+    * Note that there is no guarantee that calling a private method directly
+      will result in abiding by that policy. Hence, be careful when writing
+      custom code, and use your own context managers to get guarantees.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy currently enforced by a TensorflowModel
+      instance by accessing its `device_policy` property.
     """
 
     def __init__(
@@ -47,6 +72,7 @@ class TensorflowModel(Model):
         model: tf.keras.layers.Layer,
         loss: Union[str, tf.keras.losses.Loss],
         metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None,
+        _from_config: bool = False,
         **kwargs: Any,
     ) -> None:
         """Instantiate a Model interface wrapping a tensorflow.keras model.
@@ -66,7 +92,7 @@ class TensorflowModel(Model):
             compiled with the model and computed using the `evaluate`
             method of the returned TensorflowModel instance.
         **kwargs: Any
-            Any addition keyword argument to `tf.keras.Model.compile`
+            Any additional keyword argument to `tf.keras.Model.compile`
             may be passed.
         """
         # Type-check the input Model and wrap it up.
@@ -79,12 +105,30 @@ class TensorflowModel(Model):
         super().__init__(model)
         # Ensure the loss is a keras.Loss object and set its reduction to none.
         loss = build_keras_loss(loss, reduction=tf.keras.losses.Reduction.NONE)
-        # Compile the wrapped model and retain compilation arguments.
-        kwargs.update({"loss": loss, "metrics": metrics})
-        model.compile(**kwargs)
-        self._kwargs = kwargs
-        # Instantiate a SGD optimizer to apply updates as-provided.
-        self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0)
+        # Select the device where to place computations and move the model.
+        policy = get_device_policy()
+        self._device = select_device(gpu=policy.gpu, idx=policy.idx)
+        if not _from_config:
+            self._model = move_layer_to_device(self._model, self._device)
+        # Finalize initialization using the selected device.
+        with tf.device(self._device):
+            # Compile the wrapped model and retain compilation arguments.
+            kwargs.update({"loss": loss, "metrics": metrics})
+            self._model.compile(**kwargs)
+            self._kwargs = kwargs
+            # Instantiate a SGD optimizer to apply updates as-provided.
+            self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0)
+
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._device
+        try:
+            idx = int(device.name.rsplit(":", 1)[-1])
+        except ValueError:
+            idx = None
+        return DevicePolicy(gpu=(device.device_type == "GPU"), idx=idx)
 
     @property
     def required_data_info(
@@ -98,7 +142,8 @@ class TensorflowModel(Model):
     ) -> None:
         if not self._model.built:
             data_info = aggregate_data_info([data_info], {"input_shape"})
-            self._model.build(data_info["input_shape"])
+            with tf.device(self._device):
+                self._model.build(data_info["input_shape"])
 
     def get_config(
         self,
@@ -117,9 +162,15 @@ class TensorflowModel(Model):
         for key in ("model", "loss", "kwargs"):
             if key not in config.keys():
                 raise KeyError(f"Missing key '{key}' in the config dict.")
-        model = tf.keras.layers.deserialize(config["model"])
-        loss = tf.keras.losses.deserialize(config["loss"])
-        return cls(model, loss, **config["kwargs"])
+        # Set up the device policy.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Deserialize the model and loss keras objects on the device.
+        with tf.device(device):
+            model = tf.keras.layers.deserialize(config["model"])
+            loss = tf.keras.losses.deserialize(config["loss"])
+        # Instantiate the TensorflowModel, avoiding device-to-device copies.
+        return cls(model, loss, **config["kwargs"], _from_config=True)
 
     def get_weights(
         self,
@@ -141,8 +192,9 @@ class TensorflowModel(Model):
             )
         self._verify_weights_compatibility(weights, trainable=trainable)
         variables = {var.name: var for var in self._model.weights}
-        for name, value in weights.coefs.items():
-            variables[name].assign(value, read_value=False)
+        with tf.device(self._device):
+            for name, value in weights.coefs.items():
+                variables[name].assign(value, read_value=False)
 
     def _verify_weights_compatibility(
         self,
@@ -180,12 +232,13 @@ class TensorflowModel(Model):
         batch: Batch,
         max_norm: Optional[float] = None,
     ) -> TensorflowVector:
-        data = self._unpack_batch(batch)
-        if max_norm is None:
-            grads = self._compute_batch_gradients(*data)
-        else:
-            norm = tf.constant(max_norm)
-            grads = self._compute_clipped_gradients(*data, norm)
+        with tf.device(self._device):
+            data = self._unpack_batch(batch)
+            if max_norm is None:
+                grads = self._compute_batch_gradients(*data)
+            else:
+                norm = tf.constant(max_norm)
+                grads = self._compute_clipped_gradients(*data, norm)
         grads_and_vars = zip(grads, self._model.trainable_weights)
         return TensorflowVector(
             {var.name: grad for grad, var in grads_and_vars}
@@ -268,13 +321,14 @@ class TensorflowModel(Model):
         updates: TensorflowVector,
     ) -> None:
         self._verify_weights_compatibility(updates, trainable=True)
-        # Delegate updates' application to a tensorflow Optimizer.
-        values = (-1 * updates).coefs.values()
-        zipped = zip(values, self._model.trainable_weights)
-        upd_op = self._sgd.apply_gradients(zipped)
-        # Ensure ops have been performed before exiting.
-        with tf.control_dependencies([upd_op]):
-            return None
+        with tf.device(self._device):
+            # Delegate updates' application to a tensorflow Optimizer.
+            values = (-1 * updates).coefs.values()
+            zipped = zip(values, self._model.trainable_weights)
+            upd_op = self._sgd.apply_gradients(zipped)
+            # Ensure ops have been performed before exiting.
+            with tf.control_dependencies([upd_op]):
+                return None
 
     def evaluate(
         self,
@@ -294,21 +348,23 @@ class TensorflowModel(Model):
         metrics: dict[str, float]
             Dictionary associating evaluation metrics' values to their name.
         """
-        return self._model.evaluate(dataset, return_dict=True)
+        with tf.device(self._device):
+            return self._model.evaluate(dataset, return_dict=True)
 
     def compute_batch_predictions(
         self,
         batch: Batch,
     ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
-        inputs, y_true, s_wght = self._unpack_batch(batch)
-        if y_true is None:
-            raise TypeError(
-                "`TensorflowModel.compute_batch_predictions` received a "
-                "batch with `y_true=None`, which is unsupported. Please "
-                "correct the inputs, or override this method to support "
-                "creating labels from the base inputs."
-            )
-        y_pred = self._model(inputs, training=False).numpy()
+        with tf.device(self._device):
+            inputs, y_true, s_wght = self._unpack_batch(batch)
+            if y_true is None:
+                raise TypeError(
+                    "`TensorflowModel.compute_batch_predictions` received a "
+                    "batch with `y_true=None`, which is unsupported. Please "
+                    "correct the inputs, or override this method to support "
+                    "creating labels from the base inputs."
+                )
+            y_pred = self._model(inputs, training=False).numpy()
         y_true = y_true.numpy()
         s_wght = s_wght.numpy() if s_wght is not None else s_wght
         return y_true, y_pred, s_wght
@@ -318,7 +374,21 @@ class TensorflowModel(Model):
         y_true: np.ndarray,
         y_pred: np.ndarray,
     ) -> np.ndarray:
-        tns_true = tf.convert_to_tensor(y_true)
-        tns_pred = tf.convert_to_tensor(y_pred)
-        s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred)
+        with tf.device(self._device):
+            tns_true = tf.convert_to_tensor(y_true)
+            tns_pred = tf.convert_to_tensor(y_pred)
+            s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred)
         return s_loss.numpy().squeeze()
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        # Resolve the target device from `policy` (global policy if None).
+        if policy is None:
+            policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        # Compare devices by value: an equal one must not trigger a re-copy.
+        if self._device != device:
+            self._device = device
+            self._model = move_layer_to_device(self._model, self._device)
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index f10fa5154d3d82b65d747de31d2d8824bc9bbb2d..f347f212b97c37c3176401004f39d4d09aadf89a 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -25,30 +25,55 @@ import tensorflow as tf  # type: ignore
 from tensorflow.python.framework.ops import EagerTensor  # type: ignore
 # pylint: enable=no-name-in-module
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
+# fmt: on
 
 from declearn.model.api import Vector, register_vector_type
 from declearn.model.sklearn import NumpyVector
-
-# fmt: on
+from declearn.model.tensorflow.utils import (
+    preserve_tensor_device,
+    select_device,
+)
+from declearn.utils import get_device_policy
 
 
 @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
 class TensorflowVector(Vector):
     """Vector subclass to store tensorflow tensors.
 
-    This Vector is designed to store a collection of named
-    TensorFlow tensors, enabling computations that are either
-    applied to each and every coefficient, or imply two sets
-    of aligned coefficients (i.e. two TensorflowVector with
-    similar specifications).
+    This Vector is designed to store a collection of named TensorFlow
+    tensors, enabling computations that are either applied to each and
+    every coefficient, or imply two sets of aligned coefficients (i.e.
+    two TensorflowVector with similar specifications).
+
+    Note that support for IndexedSlices is implemented, as these are a
+    common type for auto-differentiated gradients.
 
-    Note that support for IndexedSlices is implemented,
-    as these are a common type for auto-differentiated
-    gradients.
-    Note that this class does not (yet?) support special
-    tensor types such as SparseTensor or RaggedTensor.
+    Note that this class does not (yet?) support special tensor types
+    such as SparseTensor or RaggedTensor.
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `TensorflowVector` can be operated with either a:
+      - scalar value
+      - `NumpyVector` that has similar specifications
+      - `TensorflowVector` that has similar specifications
+      => resulting in a `TensorflowVector` in each of these cases.
+    - The wrapped tensors may be placed on any device (CPU, GPU...)
+      and may not be all on the same device.
+    - The device-placement of the initial `TensorflowVector`'s data
+      is preserved by operations, including with `NumpyVector`.
+    - When combining two `TensorflowVector`, the device-placement
+      of the left-most one is used; in that case, one ends up with
+      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
+      warning will be emitted to prevent silent un-optimized copies.
+    - When deserializing a `TensorflowVector` (either by directly using
+      `TensorflowVector.unpack` or loading one from a JSON dump), loaded
+      tensors are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement schema than at dump time but
+      should be coherent with that of `TensorflowModel` computations.
     """
 
     @property
@@ -81,6 +106,23 @@ class TensorflowVector(Vector):
     ) -> None:
         super().__init__(coefs)
 
+    def apply_func(
+        self,
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Self:
+        wrapped = preserve_tensor_device(func)  # run on the input's device
+        return super().apply_func(wrapped, *args, **kwargs)
+
+    def _apply_operation(
+        self,
+        other: Any,
+        func: Callable[[Any, Any], Any],
+    ) -> Self:
+        wrapped = preserve_tensor_device(func)  # run on 1st operand's device
+        return super()._apply_operation(other, wrapped)
+
     def dtypes(
         self,
     ) -> Dict[str, str]:
@@ -97,7 +139,10 @@ class TensorflowVector(Vector):
         cls,
         data: Dict[str, Any],
     ) -> Self:
-        coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()}
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        with tf.device(device):
+            coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()}
         return cls(coef)
 
     @classmethod
@@ -149,10 +194,13 @@ class TensorflowVector(Vector):
         if not isinstance(t_a, type(t_b)):
             return False
         if isinstance(t_a, tf.IndexedSlices):
-            return TensorflowVector._tensor_equal(
-                t_a.indices, t_b.indices
-            ) and TensorflowVector._tensor_equal(t_a.values, t_b.values)
-        return tf.reduce_all(t_a == t_b).numpy()
+            # fmt: off
+            return (
+                TensorflowVector._tensor_equal(t_a.indices, t_b.indices)
+                and TensorflowVector._tensor_equal(t_a.values, t_b.values)
+            )
+        with tf.device(t_a.device):
+            return tf.reduce_all(t_a == t_b).numpy()
 
     def sign(self) -> Self:
         return self.apply_func(tf.sign)
@@ -178,8 +226,4 @@ class TensorflowVector(Vector):
         axis: Optional[int] = None,
         keepdims: bool = False,
     ) -> Self:
-        coefs = {
-            key: tf.reduce_sum(val, axis=axis, keepdims=keepdims)
-            for key, val in self.coefs.items()
-        }
-        return self.__class__(coefs)
+        return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims)
diff --git a/declearn/model/tensorflow/utils/__init__.py b/declearn/model/tensorflow/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f7d66b1d29e47c7b07dc8b10f1b6e9d11187896
--- /dev/null
+++ b/declearn/model/tensorflow/utils/__init__.py
@@ -0,0 +1,38 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for tensorflow backend support code.
+
+GPU/CPU backing device management utils:
+* move_layer_to_device:
+    Create a copy of an input keras layer placed on a given device.
+* preserve_tensor_device:
+    Wrap a tensor-processing function to have it run on its inputs' device.
+* select_device:
+    Select a backing device to use based on inputs and availability.
+
+Loss function management utils:
+* build_keras_loss:
+    Type-check, deserialize and/or wrap a keras loss into a Loss object.
+"""
+
+from ._gpu import (
+    move_layer_to_device,
+    preserve_tensor_device,
+    select_device,
+)
+from ._loss import build_keras_loss
diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8164250ab560d4db1e08146f960d3ba3f4c6a6d8
--- /dev/null
+++ b/declearn/model/tensorflow/utils/_gpu.py
@@ -0,0 +1,137 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for GPU support and device management in tensorflow."""
+
+import functools
+import warnings
+from typing import Any, Callable, Optional, Union
+
+import tensorflow as tf  # type: ignore
+
+
+__all__ = [
+    "move_layer_to_device",
+    "preserve_tensor_device",
+    "select_device",
+]
+
+
+def select_device(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> tf.config.LogicalDevice:
+    """Select a backing device to use based on inputs and availability.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to select a GPU device rather than the CPU one.
+    idx: int or None, default=None
+        Optional pre-selected device index. Ignored unless `gpu=True`.
+        If `idx is None`, or is not a valid index among available GPU
+        devices (negative or too large), use the first available one.
+
+    Warns
+    -----
+    UserWarning:
+        If `gpu=True` but no GPU is available.
+        If `idx` is not a valid index among available GPU devices.
+
+    Returns
+    -------
+    device: tf.config.LogicalDevice
+        Selected device, usable as `tf.device` argument.
+    """
+    # Only honor the index for GPU selection; use the first CPU otherwise.
+    idx = idx if (gpu and idx is not None) else 0
+    device_type = "GPU" if gpu else "CPU"
+    devices = tf.config.list_logical_devices(device_type)
+    # Case when no GPU is available: warn and use a CPU instead.
+    if gpu and not devices:
+        warnings.warn(
+            "Cannot use a GPU device: either CUDA is unavailable "
+            "or no GPU is visible to tensorflow."
+        )
+        device_type, idx = "CPU", 0
+        devices = tf.config.list_logical_devices("CPU")
+    # Case when the index is invalid (incl. negative): use the first device.
+    if not 0 <= idx < len(devices):
+        warnings.warn(
+            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
+            f"\nUsing {device_type} device n°0 instead."
+        )
+        idx = 0
+    # Return the selected device.
+    return devices[idx]
+
+
+def move_layer_to_device(
+    layer: tf.keras.layers.Layer,
+    device: Union[tf.config.LogicalDevice, str],
+) -> tf.keras.layers.Layer:
+    """Rebuild a keras layer, with copied weights, under a device scope.
+
+    The input layer is serialized, then deserialized and assigned the
+    original's weight values under a `tf.device(device)` context. As
+    this copies every weight, it may be costly and should be used
+    sparingly.
+
+    Parameters
+    ----------
+    layer: tf.keras.layers.Layer
+        Keras layer that needs moving to another device.
+    device: tf.config.LogicalDevice or str
+        Device where to place the layer's weights.
+
+    Returns
+    -------
+    layer: tf.keras.layers.Layer
+        Copy of the input layer, with its weights backed on `device`.
+    """
+    values = layer.get_weights()
+    spec = tf.keras.layers.serialize(layer)
+    with tf.device(device):
+        clone = tf.keras.layers.deserialize(spec)
+        clone.set_weights(values)
+    return clone
+
+
+def preserve_tensor_device(
+    func: Callable[..., tf.Tensor],
+) -> Callable[..., tf.Tensor]:
+    """Decorate a function to execute on its first input tensor's device.
+
+    Parameters
+    ----------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Function to wrap, that takes a tensorflow Tensor as first argument.
+
+    Returns
+    -------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Wrapper calling `func` under its first input tensor's `tf.device`.
+    """
+
+    @functools.wraps(func)
+    def wrapped(tensor: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
+        """Call the wrapped function under a `tf.device` context."""
+        with tf.device(tensor.device):
+            output = func(tensor, *args, **kwargs)
+        return output
+
+    return wrapped
diff --git a/declearn/model/tensorflow/_utils.py b/declearn/model/tensorflow/utils/_loss.py
similarity index 98%
rename from declearn/model/tensorflow/_utils.py
rename to declearn/model/tensorflow/utils/_loss.py
index b858e4e4631862a07b8639a8a0229fc6717d547d..c6b880c8e8aaa6dc90cb197e18bf7677f32285fc 100644
--- a/declearn/model/tensorflow/_utils.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Backend utils for the declearn.model.tensorflow module."""
+"""Function to parse and/or wrap a keras loss for use with declearn."""
 
 import inspect
 
diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index 352d3a398f6958488404c2ecd5aecc4a0f29a43b..efdf95f1d5a2ab94b777667fcd3787fb27fdea29 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -24,7 +24,11 @@ gradient descent.
 This module exposes:
 * TorchModel: Model subclass to wrap torch.nn.Module objects
 * TorchVector: Vector subclass to wrap torch.Tensor objects
+
+It also exposes the `utils` submodule, which mainly aims at
+providing tools used in the backend of the former objects.
 """
 
+from . import utils
 from ._vector import TorchVector
 from ._model import TorchModel
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 6dd7b2f7b2d5c1f25ac00f15762ab1510a90e888..6f99b19b3dc9cf3202512dc7a7665b98d7b284b5 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -27,10 +27,16 @@ import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
+from declearn.model.torch.utils import AutoDeviceModule, select_device
 from declearn.model.torch._vector import TorchVector
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TorchModel",
+]
 
 
 # alias for unpacked Batch structures, converted to torch.Tensor objects
@@ -43,8 +49,23 @@ TensorBatch = Tuple[
 class TorchModel(Model):
     """Model wrapper for PyTorch Model instances.
 
-    This `Model` subclass is designed to wrap a `torch.nn.Module`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `torch.nn.Module` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default torch operates on CPU, and it does not automatically move
+      tensors between devices. This means users have to be careful where
+      tensors are placed to avoid operations between tensors on different
+      devices, leading to runtime errors.
+    * Our `TorchModel` instead consults the global device-placement policy
+      (via `declearn.utils.get_device_policy`), places the wrapped torch
+      modules' weights there, and automates the placement of input data on
+      the same device as the wrapped model.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy currently enforced by a TorchModel
+      instance by accessing its `device_policy` property.
     """
 
     def __init__(
@@ -63,18 +84,29 @@ class TorchModel(Model):
             is to be minimized through training. Note that it will be
             altered when wrapped.
         """
-        # Type-check the input Model and wrap it up.
+        # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
             raise TypeError("'model' should be a torch.nn.Module instance.")
+        # Select the device where to place computations, and wrap the model.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        model = AutoDeviceModule(model, device=device)
         super().__init__(model)
         # Assign loss module and set it not to reduce sample-wise values.
         if not isinstance(loss, torch.nn.Module):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
-        self._loss_fn = loss
-        self._loss_fn.reduction = "none"  # type: ignore
+        loss.reduction = "none"  # type: ignore
+        self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
         self._func_model = functorch.make_functional(self._model)[0]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        dev = self._model.device  # device reported by the wrapped model
+        return DevicePolicy(gpu=(dev.type == "cuda"), idx=dev.index)
+
     @property
     def required_data_info(
         self,
@@ -94,15 +126,12 @@ class TorchModel(Model):
             "PyTorch JSON serialization relies on pickle, which may be unsafe."
         )
         with io.BytesIO() as buffer:
-            torch.save(self._model, buffer)
+            torch.save(self._model.module, buffer)
             model = buffer.getbuffer().hex()
         with io.BytesIO() as buffer:
-            torch.save(self._loss_fn, buffer)
+            torch.save(self._loss_fn.module, buffer)
             loss = buffer.getbuffer().hex()
-        return {
-            "model": model,
-            "loss": loss,
-        }
+        return {"model": model, "loss": loss}
 
     @classmethod
     def from_config(
@@ -141,6 +170,7 @@ class TorchModel(Model):
             state_dict.update(weights.coefs)
         else:
             state_dict = weights.coefs
+        # NOTE: this preserves the device placement of current states
         self._model.load_state_dict(state_dict)
 
     def _verify_weights_compatibility(
@@ -231,7 +261,7 @@ class TorchModel(Model):
         """Compute the average (opt. weighted) loss over given predictions."""
         loss = self._loss_fn(y_pred, y_true)
         if s_wght is not None:
-            loss.mul_(s_wght)
+            loss.mul_(s_wght.to(loss.device))
         return loss.mean()
 
     def _compute_samplewise_gradients(
@@ -270,7 +300,7 @@ class TorchModel(Model):
                 # false-positive; pylint: disable=no-member
                 grad.mul_(torch.clamp(max_norm / norm, max=1))
                 if s_wght is not None:
-                    grad.mul_(s_wght)
+                    grad.mul_(s_wght.to(grad.device))
             return grads
         # Vectorize the function to compute sample-wise clipped gradients.
         with torch.no_grad():
@@ -356,7 +386,7 @@ class TorchModel(Model):
         with torch.no_grad():
             for key, upd in updates.coefs.items():
                 tns = self._model.get_parameter(key)
-                tns.add_(upd)
+                tns.add_(upd.to(tns.device))
 
     def compute_batch_predictions(
         self,
@@ -372,9 +402,9 @@ class TorchModel(Model):
             )
         self._model.eval()
         with torch.no_grad():
-            y_pred = self._model(*inputs).numpy()
-        y_true = y_true.numpy()
-        s_wght = s_wght.numpy() if s_wght is not None else s_wght
+            y_pred = self._model(*inputs).cpu().numpy()
+        y_true = y_true.cpu().numpy()
+        s_wght = None if s_wght is None else s_wght.cpu().numpy()
         return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
@@ -385,4 +415,16 @@ class TorchModel(Model):
         tns_pred = torch.from_numpy(y_pred)  # pylint: disable=no-member
         tns_true = torch.from_numpy(y_true)  # pylint: disable=no-member
         s_loss = self._loss_fn(tns_pred, tns_true)
-        return s_loss.numpy().squeeze()
+        return s_loss.cpu().numpy().squeeze()
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        # Default to the global device-placement policy when none is given.
+        chosen = get_device_policy() if policy is None else policy
+        # Resolve the policy into an actual torch device.
+        device = select_device(gpu=chosen.gpu, idx=chosen.idx)
+        # Propagate it to both wrapped modules (model first, then loss).
+        for module in (self._model, self._loss_fn):
+            module.set_device(device)
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index 5a6ee87b5a5e1da03e9e2998d9e3338639e6e394..6b2a19e2d05ea79e8e83c4206153c6d614e3e735 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -19,25 +19,47 @@
 
 from typing import Any, Callable, Dict, Optional, Set, Tuple, Type
 
-import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
 from declearn.model.api import Vector, register_vector_type
 from declearn.model.sklearn import NumpyVector
+from declearn.model.torch.utils import select_device
+from declearn.utils import get_device_policy
 
 
 @register_vector_type(torch.Tensor)
 class TorchVector(Vector):
     """Vector subclass to store PyTorch tensors.
 
-    This Vector is designed to store a collection of named
-    PyTorch tensors, enabling computations that are either
-    applied to each and every coefficient, or imply two sets
-    of aligned coefficients (i.e. two TorchVector with
-    similar specifications).
+    This Vector is designed to store a collection of named PyTorch
+    tensors, enabling computations that are either applied to each
+    and every coefficient, or imply two sets of aligned coefficients
+    (i.e. two TorchVector with similar specifications).
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `TorchVector` can be operated with either a:
+      - scalar value
+      - `NumpyVector` that has similar specifications
+      - `TorchVector` that has similar specifications
+      => resulting in a `TorchVector` in each of these cases.
+    - The wrapped tensors may be placed on any device (CPU, GPU...)
+      and may not be all on the same device.
+    - The device-placement of the initial `TorchVector`'s data
+      is preserved by operations, including with `NumpyVector`.
+    - When combining two `TorchVector`, the device-placement
+      of the left-most one is used; in that case, one ends up with
+      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
+      warning will be emitted to prevent silent un-optimized copies.
+    - When deserializing a `TorchVector` (either by directly using
+      `TorchVector.unpack` or loading one from a JSON dump), loaded
+      tensors are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement schema than at dump time but
+      should be coherent with that of `TorchModel` computations.
     """
 
     @property
@@ -73,12 +95,20 @@ class TorchVector(Vector):
         other: Any,
         func: Callable[[Any, Any], Any],
     ) -> Self:
+        # Convert 'other' NumpyVector to a (CPU-backed) TorchVector.
         if isinstance(other, NumpyVector):
             # false-positive; pylint: disable=no-member
             coefs = {
                 key: torch.from_numpy(val) for key, val in other.coefs.items()
             }
             other = TorchVector(coefs)
+        # Ensure 'other' TorchVector shares this vector's device placement.
+        if isinstance(other, TorchVector):
+            coefs = {
+                key: val.to(self.coefs[key].device)
+                for key, val in other.coefs.items()
+            }
+            other = TorchVector(coefs)
         return super()._apply_operation(other, func)
 
     def dtypes(
@@ -95,15 +125,20 @@ class TorchVector(Vector):
     def pack(
         self,
     ) -> Dict[str, Any]:
-        return {key: tns.numpy() for key, tns in self.coefs.items()}
+        return {key: tns.cpu().numpy() for key, tns in self.coefs.items()}
 
     @classmethod
     def unpack(
         cls,
         data: Dict[str, Any],
     ) -> Self:
-        # false-positive; pylint: disable=no-member
-        coefs = {key: torch.from_numpy(dat) for key, dat in data.items()}
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        coefs = {
+            # false-positive on `torch.from_numpy`; pylint: disable=no-member
+            key: torch.from_numpy(dat).to(device)
+            for key, dat in data.items()
+        }
         return cls(coefs)
 
     def __eq__(
@@ -115,8 +150,9 @@ class TorchVector(Vector):
             valid = self.coefs.keys() == other.coefs.keys()
         if valid:
             valid = all(
-                np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy())
-                for k in self.coefs
+                # false-positive on 'torch.equal'; pylint: disable=no-member
+                torch.equal(tns, other.coefs[key].to(tns.device))
+                for key, tns in self.coefs.items()
             )
         return valid
 
diff --git a/declearn/model/torch/utils/__init__.py b/declearn/model/torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..225118fa3cc035d9bac36542a59a389906587a6e
--- /dev/null
+++ b/declearn/model/torch/utils/__init__.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for torch backend support code.
+
+GPU/CPU backing device management utils:
+* AutoDeviceModule:
+    Wrapper for a `torch.nn.Module`, automating device-management.
+* select_device:
+    Select a backing device to use based on inputs and availability.
+"""
+
+from ._gpu import AutoDeviceModule, select_device
diff --git a/declearn/model/torch/utils/_gpu.py b/declearn/model/torch/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ed28dc340721d5e92d0e65a851592f4c18fffc
--- /dev/null
+++ b/declearn/model/torch/utils/_gpu.py
@@ -0,0 +1,140 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for GPU support and device management in torch."""
+
+import warnings
+from typing import Any, Optional
+
+import torch
+
+
+__all__ = [
+    "AutoDeviceModule",
+    "select_device",
+]
+
+
+def select_device(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> torch.device:  # pylint: disable=no-member
+    """Select a backing device to use based on inputs and availability.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to select a GPU device rather than the CPU one.
+    idx: int or None, default=None
+        Optional pre-selected GPU device index. Only used when `gpu=True`.
+        If `idx is None` or exceeds the number of available GPU devices,
+        use `torch.cuda.current_device()`.
+
+    Warns
+    -----
+    UserWarning:
+        If `gpu=True` but no GPU is available.
+        If `idx` exceeds the number of available GPU devices.
+
+    Returns
+    -------
+    device: torch.device
+        Selected torch device, with type "cpu" or "cuda".
+    """
+    # Case when instructed to use the CPU device.
+    if not gpu:
+        return torch.device("cpu")  # pylint: disable=no-member
+    # Case when no GPU is available: warn and use the CPU instead.
+    if gpu and not torch.cuda.is_available():
+        warnings.warn(
+            "Cannot use a GPU device: either CUDA is unavailable "
+            "or no GPU is visible to torch."
+        )
+        return torch.device("cpu")  # pylint: disable=no-member
+    # Case when the desired GPU is invalid: select another one.
+    if (idx or 0) >= torch.cuda.device_count():
+        warnings.warn(
+            f"Cannot use GPU device n°{idx}: index is out-of-range.\n"
+            f"Using GPU device n°{torch.cuda.current_device()} instead."
+        )
+        idx = None
+    # Return the selected or auto-selected GPU device index.
+    if idx is None:
+        idx = torch.cuda.current_device()
+    return torch.device("cuda", index=idx)  # pylint: disable=no-member
+
+
+class AutoDeviceModule(torch.nn.Module):
+    """Wrapper for a `torch.nn.Module`, automating device-management.
+
+    This `torch.nn.Module` subclass enables wrapping another one, and
+    provides:
+    * a `device` attribute (and instantiation parameter) indicating
+      where the wrapped module is placed
+    * automatic placement of input tensors on that device as part of
+      `forward` calls to the module
+    * a `set_device` method to change the device and move the wrapped
+      module to it
+
+    This aims at internalizing device-management boilerplate code.
+    The wrapped module is assigned to the `module` attribute and thus
+    can be accessed directly.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        device: torch.device,  # pylint: disable=no-member
+    ) -> None:
+        """Wrap a torch Module into an AutoDeviceModule.
+
+        Parameters
+        ----------
+        module: torch.nn.Module
+            Torch module that needs wrapping.
+        device: torch.device
+            Torch device where to place the wrapped module and computations.
+        """
+        super().__init__()
+        self.device = device
+        self.module = module.to(self.device)
+
+    def forward(self, *inputs: Any) -> torch.Tensor:
+        """Run the forward computation, automating device-placement of inputs.
+
+        Please refer to `self.module.forward` for details on the wrapped
+        module's forward specifications.
+        """
+        inputs = tuple(
+            x.to(self.device) if isinstance(x, torch.Tensor) else x
+            for x in inputs
+        )
+        return self.module(*inputs)
+
+    def set_device(
+        self,
+        device: torch.device,  # pylint: disable=no-member
+    ) -> None:
+        """Move the wrapped module to a pre-selected torch device.
+
+        Parameters
+        ----------
+        device: torch.device
+            Torch device where to place the wrapped module and computations.
+        """
+        self.device = device
+        self.module.to(device)
diff --git a/declearn/test_utils/_gen_ssl.py b/declearn/test_utils/_gen_ssl.py
index 5c0cd8f395348a91289726dfd328b87e4a83dce7..3c0661a6790a199059155f1456c132a4bbdb679b 100644
--- a/declearn/test_utils/_gen_ssl.py
+++ b/declearn/test_utils/_gen_ssl.py
@@ -20,7 +20,7 @@
 import os
 import shlex
 import subprocess
-from typing import Optional, Tuple
+from typing import Collection, Optional, Tuple
 
 
 __all__ = [
@@ -32,20 +32,23 @@ def generate_ssl_certificates(
     folder: str = ".",
     c_name: str = "localhost",
     password: Optional[str] = None,
+    alt_ips: Optional[Collection[str]] = None,
+    alt_dns: Optional[Collection[str]] = None,
 ) -> Tuple[str, str, str]:
-    """Generate self-signed SSL certificates.
+    """Generate a self-signed CA and a CA-signed SSL certificate.
+
+    This function is intended to be used for testing and/or in
+    demonstration contexts, whereas real-life applications are
+    expected to use certificates signed by a trusted CA.
 
     This functions orchestrates calls to the system's `openssl`
     command in order to generate and self-sign SSL certificate
     and private-key files that may be used to encrypt network
     communications, notably for declearn.
 
-    Note that as the certificate is self-signed, it will most
-    probably not (and actually should not) be trusted in any
-    other context than when ran on an internal network or the
-    localhost. Hence this function is intended to be used in
-    testing and demonstration contexts, whereas any real-life
-    application requires a certificate signed by a trusted CA.
+    More precisely, it generates:
+    - a self-signed root certificate authority (CA)
+    - a server certificate signed by the former CA
 
     Parameters
     ----------
@@ -53,10 +56,15 @@ def generate_ssl_certificates(
         Path to the folder where to create the intermediate
         and final certificate and key PEM files.
     c_name: str
-        CommonName value for the server certificate, i.e. name or
-        IP address that will be requested by clients to access it.
+        Main domain name or IP for which the certificate is created.
     password: str or None, default=None
         Optional password used to encrypt generated private keys.
+    alt_ips: collection[str] or None, default=None
+        Optional list of additional IP addresses to certify.
+        This is only implemented for OpenSSL >= 3.0.
+    alt_dns: collection[str] or None, default=None
+        Optional list of additional domain names to certify.
+        This is only implemented for OpenSSL >= 3.0.
 
     Returns
     -------
@@ -64,36 +72,89 @@ def generate_ssl_certificates(
         Path to the client-required CA certificate PEM file.
     sv_cert: str
         Path to the server's certificate PEM file.
-    sv_priv: str
+    sv_pkey: str
         Path to the server's private key PEM file.
     """
-    # Generate self-signed CA certificate and private key.
-    ca_priv = os.path.join(folder, "ca-pkey.pem")
+    try:
+        proc = subprocess.run(
+            ["openssl", "version"], check=True, capture_output=True
+        )
+    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
+        raise RuntimeError("Failed to parse openssl version.") from exc
+    old = proc.stdout.decode().startswith("OpenSSL 1")
+    if (alt_ips or alt_dns) and old:
+        raise RuntimeError(
+            "Cannot add subject alternative names with OpenSSL version <3.0."
+        )
+    # Generate a self-signed root CA.
+    ca_cert, ca_pkey = gen_ssl_ca(folder, password)
+    # Generate a server CSR and a private key.
+    sv_csrq, sv_pkey = gen_ssl_csr(folder, c_name, alt_ips, alt_dns, password)
+    # Sign the CSR into a server certificate using the root CA.
+    sv_cert = gen_ssl_cert(folder, sv_csrq, ca_cert, ca_pkey, password, old)
+    # Return paths that are used by declearn network-communication endpoints.
+    return ca_cert, sv_cert, sv_pkey
+
+
+def gen_ssl_ca(
+    folder: str,
+    password: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Generate a self-signed CA certificate and its private key."""
+    # Set up the command to generate the self-signed CA and its key.
+    ca_pkey = os.path.join(folder, "ca-pkey.pem")
     ca_cert = os.path.join(folder, "ca-cert.pem")
     cmd = (
         "openssl req -x509 -newkey rsa:4096 -sha256 -days 365 "
-        + f"-keyout {ca_priv} -out {ca_cert} "
+        + f"-keyout {ca_pkey} -out {ca_cert} "
         + (f"-passout pass:{password} " if password else "-nodes ")
         + '-subj "/C=FR/L=Lille/O=Inria/OU=Magnet/CN=SelfSignedCA"'
     )
+    # Run the command and return the paths to the created files.
     subprocess.run(shlex.split(cmd), check=True, capture_output=True)
-    # Generate server private key and CSR (certificate signing request).
-    sv_priv = os.path.join(folder, "server-pkey.pem")
-    sv_csrq = os.path.join(folder, "server-req.pem")
+    return ca_cert, ca_pkey
+
+
+def gen_ssl_csr(
+    folder: str,
+    c_name: str,
+    alt_ips: Optional[Collection[str]] = None,
+    alt_dns: Optional[Collection[str]] = None,
+    password: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Generate a CSR (certificate signing request) and its private key."""
+    sv_pkey = os.path.join(folder, "server-pkey.pem")
+    sv_csrq = os.path.join(folder, "server-csrq.pem")
     cmd = (
         "openssl req -newkey rsa:4096 "
-        + f"-keyout {sv_priv} -out {sv_csrq} "
-        + (f"-passout pass:{password} " if password else "-nodes ")
-        + f'-subj "/C=FR/L=Lille/O=Inria/OU=Magnet/CN={c_name}"'
+        + f"-keyout {sv_pkey} -out {sv_csrq} "
+        + f"-subj /C=FR/L=Lille/O=Inria/OU=Magnet/CN={c_name}"
+        + (f" -passout pass:{password}" if password else " -nodes")
     )
+    alt_names = [f"IP.{i}:{x}" for i, x in enumerate(alt_ips or tuple(), 1)]
+    alt_names += [f"DNS.{i}:{x}" for i, x in enumerate(alt_dns or tuple(), 1)]
+    if alt_names:
+        cmd += " -addext subjectAltName=" + ",".join(alt_names)
     subprocess.run(shlex.split(cmd), check=True, capture_output=True)
-    # Generate self-signed server certificate.
+    return sv_csrq, sv_pkey
+
+
+def gen_ssl_cert(
+    folder: str,
+    sv_csrq: str,
+    ca_cert: str,
+    ca_pkey: str,
+    password: Optional[str] = None,
+    old: bool = False,  # flag when using an old version (OpenSSL 1.x)
+) -> str:
+    """Sign a CSR into a certificate using a given CA."""
+    # private method; pylint: disable=too-many-arguments
     sv_cert = os.path.join(folder, "server-cert.pem")
     cmd = (
         f"openssl x509 -req -sha256 -days 30 -in {sv_csrq} -out {sv_cert} "
-        + f"-CA {ca_cert} -CAkey {ca_priv} -CAcreateserial"
-        + (f" -passin pass:{password} " if password else "")
+        + f"-CA {ca_cert} -CAkey {ca_pkey} -CAcreateserial"
+        + (" -copy_extensions=copy" if not old else "")
+        + (f" -passin pass:{password}" if password else "")
     )
     subprocess.run(shlex.split(cmd), check=True, capture_output=True)
-    # Return paths that are used in tests.
-    return ca_cert, sv_cert, sv_priv
+    return sv_cert
diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py
index d789cec4933795a0364b3083164ff449b9f79c57..27f821674acec621db1fbf14d569d0c9505b55ea 100644
--- a/declearn/test_utils/_vectors.py
+++ b/declearn/test_utils/_vectors.py
@@ -87,7 +87,8 @@ class GradientsTestCase:
             return array
         if self.framework == "tensorflow":
             tensorflow = importlib.import_module("tensorflow")
-            return tensorflow.convert_to_tensor(array)
+            with tensorflow.device("CPU"):
+                return tensorflow.convert_to_tensor(array)
         if self.framework == "torch":
             torch = importlib.import_module("torch")
             return torch.from_numpy(array)
diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py
index 00bf9fc14932fef61b6eaa4ab9b8aefeb404374c..bee5e57cbbe3f30f9b6acd2753ad9096c3c25ab8 100644
--- a/declearn/utils/__init__.py
+++ b/declearn/utils/__init__.py
@@ -67,6 +67,17 @@ And examples of pre-registered (de)serialization functions:
 * (deserialize_numpy, serialize_numpy):
     Pair of functions to (un)pack a numpy ndarray as JSON-serializable data.
 
+Device-policy utils
+-------------------
+Utils to access or update parameters defining a global device-selection policy.
+
+* DevicePolicy:
+    Dataclass to store parameters defining a device-selection policy.
+* get_device_policy:
+    Access a copy of the current global device policy.
+* set_device_policy:
+    Update the current global device policy.
+
 Miscellaneous
 -------------
 
@@ -84,6 +95,11 @@ from ._dataclass import (
     dataclass_from_func,
     dataclass_from_init,
 )
+from ._device_policy import (
+    DevicePolicy,
+    get_device_policy,
+    set_device_policy,
+)
 from ._json import (
     add_json_support,
     json_dump,
diff --git a/declearn/utils/_device_policy.py b/declearn/utils/_device_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26ea9653d4018d88c23e0020fb9aee38b968a72
--- /dev/null
+++ b/declearn/utils/_device_policy.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to define a computation device policy.
+
+This private submodule defines:
+* A dataclass defining a standard to hold a device-selection policy.
+* A private global variable holding the current package-wide device policy.
+* A public pair of functions acting as a getter and a setter for that variable.
+"""
+
+import dataclasses
+from typing import Optional
+
+
+__all__ = [
+    "DevicePolicy",
+    "get_device_policy",
+    "set_device_policy",
+]
+
+
+@dataclasses.dataclass
+class DevicePolicy:
+    """Dataclass to store parameters defining a device-selection policy.
+
+    This class merely defines a shared language of keyword-arguments to
+    define whether to back computations on CPU or on a GPU device.
+
+    It is meant to be instantiated as a global variable that holds that
+    information, and can be accessed by framework-specific backend code
+    so as to take the required steps towards implementing that policy.
+
+    To access or update the current global DevicePolicy, please use the
+    getter and setter functions: `declearn.utils.get_device_policy` and
+    `declearn.utils.set_device_policy`.
+
+    Attributes
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None
+        Optional index of the GPU device to use.
+        If None, select one arbitrarily.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
+    """
+
+    gpu: bool
+    idx: Optional[int]
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.gpu, bool):
+            raise TypeError(
+                f"DevicePolicy 'gpu' should be a bool, not '{type(self.gpu)}'."
+            )
+        if not (self.idx is None or isinstance(self.idx, int)):
+            raise TypeError(
+                "DevicePolicy 'idx' should be None or an int, not "
+                f"'{type(self.idx)}'."
+            )
+
+
+DEVICE_POLICY = DevicePolicy(gpu=True, idx=None)
+
+
+def get_device_policy() -> DevicePolicy:
+    """Return a copy of the current global device policy.
+
+    This method is meant to be used:
+    - By end-users that wish to check the current device policy.
+    - By the backend code of framework-specific objects so as to
+      take the required steps towards implementing that policy.
+
+    To update the current policy, use `declearn.utils.set_device_policy`.
+
+    Returns
+    -------
+    policy: DevicePolicy
+        DevicePolicy dataclass instance, wrapping parameters that specify
+        the device policy to be enforced by Model and Vector to properly
+        place data and computations.
+    """
+    return DevicePolicy(**dataclasses.asdict(DEVICE_POLICY))
+
+
+def set_device_policy(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> None:
+    """Update the current global device policy.
+
+    To access the current policy, use `declearn.utils.get_device_policy`.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None, default=None
+        Optional index of the GPU device to use.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
+    """
+    # Using a global statement to have a proper setter to a private variable.
+    global DEVICE_POLICY  # pylint: disable=global-statement
+    DEVICE_POLICY = DevicePolicy(gpu, idx)
diff --git a/test/metrics/test_binary_apr.py b/test/metrics/test_binary_apr.py
index 0473f550d12dbd8811f56e46a52dcb9cd64ef701..803b7fb6039421058f1e9f77e6b90d5202469830 100644
--- a/test/metrics/test_binary_apr.py
+++ b/test/metrics/test_binary_apr.py
@@ -33,6 +33,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from metric_testing import MetricTestCase, MetricTestSuite
 sys.path.pop()
 # pylint: enable=wrong-import-order, wrong-import-position
+# fmt: on
 
 
 @pytest.fixture(name="test_case")
@@ -45,9 +46,8 @@ def test_case_fixture(
         _test_case_1d(thresh) if case == "1d" else _test_case_2d(thresh)
     )
     # Add the F1-score to expected scores.
-    scores["f-score"] = (
-        (states["tpos"] + states["tpos"])
-        / (states["tpos"] + states["tpos"] + states["fpos"] + states["fneg"])
+    scores["f-score"] = (states["tpos"] + states["tpos"]) / (
+        states["tpos"] + states["tpos"] + states["fpos"] + states["fneg"]
     )
     # Add the confusion matrix to expected scores.
     confmt = [
@@ -82,16 +82,12 @@ def _test_case_1d(
         0.3: {"tpos": 2.0, "tneg": 0.0, "fpos": 2.0, "fneg": 0.0},
         0.5: {"tpos": 2.0, "tneg": 1.0, "fpos": 1.0, "fneg": 0.0},
         0.7: {"tpos": 1.0, "tneg": 1.0, "fpos": 1.0, "fneg": 1.0},
-    }[
-        thresh
-    ]
+    }[thresh]
     scores = {
         0.3: {"accuracy": 2 / 4, "precision": 2 / 4, "recall": 2 / 2},
         0.5: {"accuracy": 3 / 4, "precision": 2 / 3, "recall": 2 / 2},
         0.7: {"accuracy": 2 / 4, "precision": 1 / 2, "recall": 1 / 2},
-    }[
-        thresh
-    ]
+    }[thresh]
     return inputs, states, scores  # type: ignore
 
 
@@ -124,16 +120,12 @@ def _test_case_2d(
         0.3: {"tpos": 6.0, "tneg": 5.0, "fpos": 1.0, "fneg": 0.0},
         0.5: {"tpos": 4.0, "tneg": 5.0, "fpos": 1.0, "fneg": 2.0},
         0.7: {"tpos": 3.0, "tneg": 6.0, "fpos": 0.0, "fneg": 3.0},
-    }[
-        thresh
-    ]
+    }[thresh]
     scores = {
         0.3: {"accuracy": 11 / 12, "precision": 6 / 7, "recall": 6 / 6},
         0.5: {"accuracy": 9 / 12, "precision": 4 / 5, "recall": 4 / 6},
         0.7: {"accuracy": 9 / 12, "precision": 3 / 3, "recall": 3 / 6},
-    }[
-        thresh
-    ]
+    }[thresh]
     return inputs, states, scores  # type: ignore
 
 
diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py
index 880b843183a3f8def68e45d2d20c8e4e3c0d3bcf..588b8421ac30d8ceea318564e0a5c0018ed30c64 100644
--- a/test/metrics/test_binary_roc.py
+++ b/test/metrics/test_binary_roc.py
@@ -34,6 +34,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from metric_testing import MetricTestCase, MetricTestSuite
 sys.path.pop()
 # pylint: enable=wrong-import-order, wrong-import-position
+# fmt: on
 
 
 @pytest.fixture(name="test_case")
@@ -81,11 +82,13 @@ def test_case_fixture(
     )
 
 
-def _test_case_1d() -> Tuple[
-    Dict[str, np.ndarray],
-    Dict[str, Union[float, np.ndarray]],
-    Dict[str, Union[float, np.ndarray]],
-]:
+def _test_case_1d() -> (
+    Tuple[
+        Dict[str, np.ndarray],
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Union[float, np.ndarray]],
+    ]
+):
     """Return a test case with 1-D samples (standard binary classif)."""
     # similar inputs as for Binary APR; pylint: disable=duplicate-code
     inputs = {
@@ -125,11 +128,13 @@ def _test_case_1d() -> Tuple[
     return inputs, states, scores
 
 
-def _test_case_2d() -> Tuple[
-    Dict[str, np.ndarray],
-    Dict[str, Union[float, np.ndarray]],
-    Dict[str, Union[float, np.ndarray]],
-]:
+def _test_case_2d() -> (
+    Tuple[
+        Dict[str, np.ndarray],
+        Dict[str, Union[float, np.ndarray]],
+        Dict[str, Union[float, np.ndarray]],
+    ]
+):
     """Return a test case with 2-D samples (multilabel binary classif)."""
     # similar inputs as for Binary APR; pylint: disable=duplicate-code
     inputs = {
diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py
index 823e89bb97f156177ebd5959c6cd3973f2f08813..13fa0e930dff2f50253c45ae780b18a50a965fcd 100644
--- a/test/metrics/test_mae_mse.py
+++ b/test/metrics/test_mae_mse.py
@@ -34,13 +34,13 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from metric_testing import MetricTestCase, MetricTestSuite
 sys.path.pop()
 # pylint: enable=wrong-import-order, wrong-import-position
+# fmt: on
 
 
 @pytest.fixture(name="test_case")
 def test_case_fixture(
     case: Literal["mae", "mse"],
     weighted: bool,
-    # n_dims: int,
 ) -> MetricTestCase:
     """Return a test case for a MAE or MSE metric, with opt. sample weights."""
     # Generate random inputs and compute the expected sum of errors.
diff --git a/test/metrics/test_multi_apr.py b/test/metrics/test_multi_apr.py
index 659bf706ad1d44bfd91dd9b1bbcda1bbdfc7453e..83c0cb78f8d3cd3d82309fc4dd27028153903d5e 100644
--- a/test/metrics/test_multi_apr.py
+++ b/test/metrics/test_multi_apr.py
@@ -32,6 +32,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from metric_testing import MetricTestCase, MetricTestSuite
 sys.path.pop()
 # pylint: enable=wrong-import-order, wrong-import-position
+# fmt: on
 
 
 @pytest.fixture(name="test_case")
diff --git a/test/metrics/test_rsquared.py b/test/metrics/test_rsquared.py
new file mode 100644
index 0000000000000000000000000000000000000000..56550b7fa45372057e7c5d8990fa502ebfb1c2b3
--- /dev/null
+++ b/test/metrics/test_rsquared.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit and functional tests for the R^2 Metric subclasses."""
+
+import os
+import sys
+from typing import Dict, Union
+
+import numpy as np
+import pytest
+from sklearn.metrics import mean_squared_error, r2_score  # type: ignore
+
+from declearn.metrics import RSquared
+
+# dirty trick to import from `metric_testing.py`;
+# fmt: off
+# pylint: disable=wrong-import-order, wrong-import-position
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from metric_testing import MetricTestCase
+from test_mae_mse import MeanMetricTestSuite
+sys.path.pop()
+# pylint: enable=wrong-import-order, wrong-import-position
+# fmt: on
+
+
+@pytest.fixture(name="test_case")
+def test_case_fixture(
+    weighted: bool,
+) -> MetricTestCase:
+    """Return a test case for an R2 metric, with opt. sample weights."""
+    # Generate random inputs and sample weights.
+    y_true = np.random.normal(scale=1.0, size=32)
+    y_pred = y_true + np.random.normal(scale=0.5, size=32)
+    s_wght = np.abs(np.random.normal(size=32)) if weighted else np.ones((32,))
+    inputs = {"y_true": y_true, "y_pred": y_pred, "s_wght": s_wght}
+    # Compute expected intermediate and final results.
+    mse = mean_squared_error(y_true, y_pred, sample_weight=s_wght)
+    states = {
+        "sum_of_squared_errors": s_wght.sum() * mse,
+        "sum_of_squared_labels": np.sum(s_wght * np.square(y_true)),
+        "sum_of_labels": np.sum(s_wght * y_true),
+        "sum_of_weights": s_wght.sum(),
+    }
+    scores = {
+        "r2": r2_score(y_true, y_pred, sample_weight=s_wght)
+    }  # type: Dict[str, Union[float, np.ndarray]]
+    # Compute derived aggregation results. Wrap as a test case and return.
+    agg_states = {key: 2 * val for key, val in states.items()}
+    agg_scores = scores.copy()
+    metric = RSquared()
+    return MetricTestCase(
+        metric, inputs, states, scores, agg_states, agg_scores
+    )
+
+
+@pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"])
+class TestRSquared(MeanMetricTestSuite):
+    """Unit tests for `RSquared` Metric."""
+
+    @staticmethod
+    def dicts_equal(
+        dicts_a: Dict[str, Union[float, np.ndarray]],
+        dicts_b: Dict[str, Union[float, np.ndarray]],
+    ) -> bool:
+        # Override the base behaviour: allow for very small (10^-12)
+        # numerical imprecisions in values' equality assertions.
+        try:
+            assert dicts_a.keys() == dicts_b.keys()
+            for key, v_a in dicts_a.items():
+                v_b = dicts_b[key]
+                assert np.allclose(v_a, v_b, rtol=0, atol=1e-12)
+        except AssertionError:
+            return False
+        return True
+
+    def test_zero_result(self, test_case: MetricTestCase) -> None:
+        """Test that `get_result` works with zero-valued divisor."""
+        metric = test_case.metric
+        # Case when no samples have been seen: return 0.
+        assert metric.get_result() == {metric.name: 0.0}
+        # Case when SSt is null but SSr is not: return 0.
+        states = getattr(metric, "_states")
+        states["sum_of_weights"] = 1.0
+        states["sum_of_squared_errors"] = 0.1
+        assert metric.get_result() == {metric.name: 0.0}
+        # Case when SSt and SSr are null but samples have been seen: return 1.
+        states["sum_of_squared_errors"] = 0.0
+        assert metric.get_result() == {metric.name: 1.0}
diff --git a/test/model/model_testing.py b/test/model/model_testing.py
index 8becfd3f7919694bf52b9e6bcbae0ce094773a4d..c103019c2982ba3835dc4b7aa7abca71f86b7703 100644
--- a/test/model/model_testing.py
+++ b/test/model/model_testing.py
@@ -18,7 +18,7 @@
 """Shared testing code for TensorFlow and Torch models' unit tests."""
 
 import json
-from typing import Any, List, Protocol, Tuple, Type, Union
+from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
@@ -27,10 +27,13 @@ from declearn.typing import Batch
 from declearn.utils import json_pack, json_unpack
 
 
-class ModelTestCase(Protocol):
+VectorT = TypeVar("VectorT", bound=Vector)
+
+
+class ModelTestCase(Protocol, Generic[VectorT]):
     """TestCase fixture-provider protocol."""
 
-    vector_cls: Type[Vector]
+    vector_cls: VectorT
     tensor_cls: Union[Type[Any], Tuple[Type[Any], ...]]
 
     @staticmethod
@@ -51,6 +54,12 @@ class ModelTestCase(Protocol):
     ) -> Model:
         """Suited toy binary-classification model."""
 
+    def assert_correct_device(
+        self,
+        vector: VectorT,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+
 
 class ModelTestSuite:
     """Unit tests for a declearn Model."""
@@ -64,6 +73,7 @@ class ModelTestSuite:
         config = json.dumps(model.get_config())
         other = model.from_config(json.loads(config))
         assert model.get_config() == other.get_config()
+        assert model.device_policy == other.device_policy
 
     def test_get_set_weights(
         self,
@@ -75,7 +85,12 @@ class ModelTestSuite:
         assert isinstance(w_srt, test_case.vector_cls)
         w_end = w_srt + 1.0
         model.set_weights(w_end)
-        assert model.get_weights() == w_end
+        w_upd = model.get_weights()
+        assert w_upd == w_end
+        # Check that weight tensors are properly placed.
+        test_case.assert_correct_device(w_srt)
+        test_case.assert_correct_device(w_end)
+        test_case.assert_correct_device(w_upd)
 
     def test_compute_batch_gradients(
         self,
@@ -113,7 +128,13 @@ class ModelTestSuite:
         np_grads = model.compute_batch_gradients(np_batch)  # type: ignore
         assert isinstance(np_grads, test_case.vector_cls)
         my_grads = model.compute_batch_gradients(my_batch)
-        assert my_grads == np_grads
+        # Allow for a numerical imprecision below 10^-8.
+        diff = my_grads - np_grads
+        max_err = max(
+            np.abs(test_case.to_numpy(weight)).max()
+            for weight in diff.coefs.values()
+        )
+        assert max_err < 1e-8
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -137,6 +158,9 @@ class ModelTestSuite:
             for k in grads_a.coefs
         )
         assert grads_a != grads_b
+        # Check that gradients are properly placed.
+        test_case.assert_correct_device(grads_a)
+        test_case.assert_correct_device(grads_b)
 
     def test_apply_updates(
         self,
@@ -156,9 +180,14 @@ class ModelTestSuite:
         # Check up to 1e-6 numerical precision due to tensor/np conversion.
         w_end = model.get_weights(trainable=True)
         assert w_end != w_srt
-        updt = [test_case.to_numpy(val) for val in grads.coefs.values()]
-        diff = list((w_end - w_srt).coefs.values())
-        assert all(np.abs(a - b).max() < 1e-6 for a, b in zip(diff, updt))
+        diff = (w_end - w_srt) - grads
+        assert all(
+            np.abs(test_case.to_numpy(weight)).max() < 1e-6
+            for weight in diff.coefs.values()
+        )
+        # Check that gradients and updated weights are properly placed.
+        test_case.assert_correct_device(grads)
+        test_case.assert_correct_device(w_end)
 
     def test_serialize_gradients(
         self,
diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py
index b214f28ac1c0fb595365dec187ca8ac94d52ed71..e53febae8f98d3fd43fbabc61695e9ae366c054b 100644
--- a/test/model/test_sksgd.py
+++ b/test/model/test_sksgd.py
@@ -106,6 +106,12 @@ class SklearnSGDTestCase(ModelTestCase):
         model.initialize(data_info)
         return model
 
+    def assert_correct_device(
+        self,
+        vector: NumpyVector,
+    ) -> None:
+        """Do nothing: no device placement is checked for numpy vectors."""
+
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index 67bb3c707631594e25f6984e986a27a8479c78b2..49ae73225dc910ee43652cf1d0cbce6847a335b8 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -33,6 +33,7 @@ except ModuleNotFoundError:
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
 from declearn.typing import Batch
+from declearn.utils import set_device_policy
 
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
@@ -69,11 +70,16 @@ class TensorflowTestCase(ModelTestCase):
     def __init__(
         self,
         kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
         if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
             raise ValueError(f"Invalid keras test architecture: '{kind}'.")
+        if device not in ("CPU", "GPU"):
+            raise ValueError(f"Invalid device choice for test: '{device}'.")
         self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "GPU"), idx=0)
 
     @staticmethod
     def to_numpy(
@@ -81,7 +87,7 @@ class TensorflowTestCase(ModelTestCase):
     ) -> np.ndarray:
         """Convert an input tensor to a numpy array."""
         assert isinstance(tensor, tf.Tensor)
-        return tensor.numpy()  # type: ignore
+        return tensor.numpy()
 
     @property
     def dataset(
@@ -136,15 +142,32 @@ class TensorflowTestCase(ModelTestCase):
         tfmod.build(shape)  # as model is built, no data_info is required
         return TensorflowModel(tfmod, loss="binary_crossentropy", metrics=None)
 
+    def assert_correct_device(
+        self,
+        vector: TensorflowVector,
+    ) -> None:
+        """Assert that all of a vector's tensors sit on the expected device."""
+        # Compare the tail of each tensor's device name with "<device>:0".
+        suffix = f"{self.device}:0"
+        devices = [tensor.device for tensor in vector.coefs.values()]
+        assert all(device.endswith(suffix) for device in devices)
+
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
-    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"]
+    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    device: Literal["CPU", "GPU"],
 ) -> TensorflowTestCase:
     """Fixture to access a TensorflowTestCase."""
-    return TensorflowTestCase(kind)
+    return TensorflowTestCase(kind, device)
+
+
+DEVICES = ["CPU"]
+if tf.config.list_logical_devices("GPU"):
+    DEVICES.append("GPU")
 
 
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
 class TestTensorflowModel(ModelTestSuite):
     """Unit tests for declearn.model.tensorflow.TensorflowModel."""
@@ -179,3 +202,17 @@ class TestTensorflowModel(ModelTestSuite):
         model.set_weights(w_trn, trainable=True)
         with pytest.raises(KeyError):
             model.set_weights(model.get_weights(), trainable=True)
+
+    def test_proper_model_placement(
+        self,
+        test_case: TensorflowTestCase,
+    ) -> None:
+        """Check that a new model and its weights follow the device policy."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.idx == 0
+        assert policy.gpu == (test_case.device == "GPU")
+        # Every weight of the wrapped model must sit on "<device>:0".
+        suffix = f"{test_case.device}:0"
+        for weight in getattr(model, "_model").weights:
+            assert weight.device.endswith(suffix)
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index d14b109fbc1d5e837b295affae687f1d078d5e1d..15031107378231ef5b1a4ec62714da1e7988a53e 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -17,6 +17,7 @@
 
 """Unit tests for TorchModel."""
 
+import json
 import sys
 from typing import Any, List, Literal, Tuple
 
@@ -30,6 +31,7 @@ except ModuleNotFoundError:
 
 from declearn.model.torch import TorchModel, TorchVector
 from declearn.typing import Batch
+from declearn.utils import set_device_policy
 
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
@@ -84,11 +86,14 @@ class TorchTestCase(ModelTestCase):
     def __init__(
         self,
         kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
         if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
             raise ValueError(f"Invalid torch test architecture: '{kind}'.")
         self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "GPU"), idx=0)
 
     @staticmethod
     def to_numpy(
@@ -96,7 +101,7 @@ class TorchTestCase(ModelTestCase):
     ) -> np.ndarray:
         """Convert an input tensor to a numpy array."""
         assert isinstance(tensor, torch.Tensor)
-        return tensor.numpy()  # type: ignore
+        return tensor.cpu().numpy()
 
     @property
     def dataset(
@@ -135,7 +140,7 @@ class TorchTestCase(ModelTestCase):
         elif self.kind == "RNN":
             stack = [
                 torch.nn.Embedding(100, 32),
-                torch.nn.LSTM(32, 16, batch_first=True),  # type: ignore
+                torch.nn.LSTM(32, 16, batch_first=True),
                 ExtractLSTMFinalOutput(),
                 torch.nn.Tanh(),
                 torch.nn.Linear(16, 1),
@@ -156,15 +161,32 @@ class TorchTestCase(ModelTestCase):
         nnmod = torch.nn.Sequential(*stack)
         return TorchModel(nnmod, loss=torch.nn.BCELoss())
 
+    def assert_correct_device(
+        self,
+        vector: TorchVector,
+    ) -> None:
+        """Assert that all of a vector's tensors sit on the expected device."""
+        # Only the device type is compared, not the device index.
+        expected = "cpu" if self.device != "GPU" else "cuda"
+        found = {tensor.device.type for tensor in vector.coefs.values()}
+        assert found <= {expected}
+
 
 @pytest.fixture(name="test_case")
 def fixture_test_case(
     kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    device: Literal["CPU", "GPU"],
 ) -> TorchTestCase:
     """Fixture to access a TorchTestCase."""
-    return TorchTestCase(kind)
+    return TorchTestCase(kind, device)
+
 
+DEVICES = ["CPU"]
+if torch.cuda.device_count():
+    DEVICES.append("GPU")
 
+
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
 class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
@@ -179,12 +201,33 @@ class TestTorchModel(ModelTestSuite):
             #       due to the (de)serialization of a custom nn.Module
             #       the expected model behaviour is, however, correct
             try:
-                super().test_serialization(test_case)
+                self._test_serialization(test_case)
             except AssertionError:
                 pytest.skip(
                     "skipping failed test due to custom nn.Module pickling"
                 )
-        super().test_serialization(test_case)
+        self._test_serialization(test_case)
+
+    def _test_serialization(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        """Verify that a model survives a JSON round-trip of its config.
+
+        This is used in place of the parent `test_serialization` method.
+        """
+        # Build a model, then a second one from its JSON-dumped config.
+        model = test_case.model
+        dumped = json.dumps(model.get_config())
+        clone = model.from_config(json.loads(dumped))
+        # Both models should report identical device policies.
+        assert model.device_policy == clone.device_policy
+        # Both models should wrap the same sequence of modules.
+        mods_a = list(getattr(model, "_model").modules())
+        mods_b = list(getattr(clone, "_model").modules())
+        assert len(mods_a) == len(mods_b)
+        for m_a, m_b in zip(mods_a, mods_b):
+            assert isinstance(m_a, type(m_b)) and repr(m_a) == repr(m_b)
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -233,3 +276,17 @@ class TestTorchModel(ModelTestSuite):
         with pytest.raises(KeyError):
             model.set_weights(model.get_weights(), trainable=True)
         model.set_weights(w_trn, trainable=True)
+
+    def test_proper_model_placement(
+        self,
+        test_case: TorchTestCase,
+    ) -> None:
+        """Check that a new model and its weights follow the device policy."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "GPU")
+        # A GPU policy must target device 0; a CPU policy has no index.
+        assert (policy.idx is None) if not policy.gpu else (policy.idx == 0)
+        expected = "cuda" if test_case.device != "CPU" else "cpu"
+        for param in getattr(model, "_model").module.parameters():
+            assert param.device.type == expected
diff --git a/test/model/test_vector.py b/test/model/test_vector.py
index 6374ad409370145a4f6b2ca4b8a953f3bf9f777e..ca5592a76d25067569903d78b652f202b70fd783 100644
--- a/test/model/test_vector.py
+++ b/test/model/test_vector.py
@@ -32,7 +32,10 @@ from declearn.test_utils import (
     GradientsTestCase,
     list_available_frameworks,
 )
-from declearn.utils import json_pack, json_unpack
+from declearn.utils import json_pack, json_unpack, set_device_policy
+
+
+set_device_policy(gpu=False)  # run Vector unit tests on CPU only
 
 
 @pytest.fixture(name="framework", params=list_available_frameworks())
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 0bbde31ad6746bd94b2cc631c4d2b0a991645682..ffd1e3132ffb64691b04ebdd03cdad4bd5d4e355 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -44,7 +44,7 @@ from declearn.test_utils import (
     GradientsTestCase,
     assert_json_serializable_dict,
 )
-from declearn.utils import access_types_mapping
+from declearn.utils import access_types_mapping, set_device_policy
 
 # relative import; pylint: disable=wrong-import-order, wrong-import-position
 # fmt: off
@@ -56,6 +56,8 @@ sys.path.pop()
 
 OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
 
+set_device_policy(gpu=False)  # run all OptiModule tests on CPU
+
 
 @pytest.mark.parametrize(
     "cls", OPTIMODULE_SUBCLASSES.values(), ids=OPTIMODULE_SUBCLASSES.keys()