diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 6ab4c88087b4cc775f17aa9dc3b6ec9f9d19b78f..22d72c08f3edce7879722aed5cf052e4d5171d31 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -19,10 +19,16 @@
 
 import typing
 import warnings
-from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, Literal, Optional, Set, Tuple, Type, Union
+)
 
 import numpy as np
+import pandas as pd
+import sklearn  # type: ignore
 from numpy.typing import ArrayLike
+from scipy.sparse import spmatrix  # type: ignore
 from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
@@ -56,6 +62,54 @@ REG_LOSSES = (
 )
 
 
+DataArray = Union[np.ndarray, spmatrix, pd.DataFrame, pd.Series]
+
+
+def select_sgd_model_dtype(
+    model: Union[SGDClassifier, SGDRegressor],
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Return the current or future dtype of a scikit-learn SGD model.
+
+    Emit a `RuntimeWarning` if the model has already been initialized
+    with a dtype that does not match the input one, and/or if the
+    scikit-learn version makes it impossible to enforce a non-default
+    dtype.
+    """
+    # Identify the user-defined dtype's name.
+    dtype = _get_dtype_name(dtype)
+    # Warn if the dtype is already defined and does not match the argument.
+    if hasattr(model, "coef_") and (model.coef_.dtype.name != dtype):
+        warnings.warn(
+            f"Cannot enforce dtype '{dtype}' for pre-initialized "
+            f"scikit-learn SGD model (dtype '{model.coef_.dtype.name}').",
+            RuntimeWarning,
+        )
+        return model.coef_.dtype.name
+    # When using scikit-learn <= 1.3, warn about un-settable dtype.
+    if int(sklearn.__version__.split(".", 2)[1]) < 3 and (dtype != "float64"):
+        warnings.warn(
+            "Using scikit-learn <1.3; hence the 'float64' dtype will"
+            f"forcibly be used rather than user-input '{dtype}'.",
+            RuntimeWarning,
+        )
+        return "float64"
+    return dtype
+
+
+def _get_dtype_name(
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Gather the name of an input numpy dtype."""
+    if isinstance(dtype, np.dtype):
+        return dtype.name
+    if isinstance(dtype, type) and issubclass(dtype, np.number):
+        return dtype.__name__
+    if isinstance(dtype, str):
+        return dtype
+    raise TypeError("'dtype' should be a str, np.dtype or np.number type.")
+
+
 @register_type(name="SklearnSGDModel", group="Model")
 class SklearnSGDModel(Model):
     """Model wrapper for Scikit-Learn SGDClassifier and SGDRegressor.
@@ -76,6 +130,7 @@ class SklearnSGDModel(Model):
     def __init__(
         self,
         model: Union[SGDClassifier, SGDRegressor],
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> None:
         """Instantiate a Model interfacing a sklearn SGD-based model.
 
@@ -89,6 +144,14 @@ class SklearnSGDModel(Model):
             Scikit-learn model that needs wrapping for federated training.
             Note that some hyperparameters will be overridden, as will the
             model's existing weights (if any).
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both these conditions are met:
+                - `model` is an un-initialized instance;
+                  otherwise the dtype of existing coefficients is used.
+                - the scikit-learn version is >= 1.3;
+                  otherwise the dtype is forced to float64
         """
         if not isinstance(model, (SGDClassifier, SGDRegressor)):
             raise TypeError(
@@ -102,7 +165,7 @@ class SklearnSGDModel(Model):
             average=False,
         )
         super().__init__(model)
-        self._initialized = False
+        self._dtype = select_sgd_model_dtype(model, dtype)
         self._predict = (
             self._model.decision_function
             if isinstance(model, SGDClassifier)
@@ -122,6 +185,8 @@ class SklearnSGDModel(Model):
     def required_data_info(
         self,
     ) -> Set[str]:
+        if hasattr(self._model, "coef_"):
+            return set()
         if isinstance(self._model, SGDRegressor):
             return {"features_shape"}
         return {"features_shape", "classes"}
@@ -130,6 +195,9 @@ class SklearnSGDModel(Model):
         self,
         data_info: Dict[str, Any],
     ) -> None:
+        # Skip for pre-initialized models.
+        if hasattr(self._model, "coef_"):
+            return
         # Check that required fields are available and of valid type.
         data_info = aggregate_data_info([data_info], self.required_data_info)
         if not (
@@ -145,14 +213,12 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros((dim, feat))
-            self._model.intercept_ = np.zeros((dim,))
+            self._model.coef_ = np.zeros((dim, feat), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((dim,), dtype=self._dtype)
         # SGDRegressor case.
         else:
-            self._model.coef_ = np.zeros((feat,))
-            self._model.intercept_ = np.zeros((1,))
-        # Mark the SklearnSGDModel as initialized.
-        self._initialized = True
+            self._model.coef_ = np.zeros((feat,), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((1,), dtype=self._dtype)
 
     @classmethod
     def from_parameters(
@@ -165,6 +231,7 @@ class SklearnSGDModel(Model):
         epsilon: float = 0.1,
         fit_intercept: bool = True,
         n_jobs: Optional[int] = None,
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> Self:
         """Instantiate a SklearnSGDModel from model parameters.
 
@@ -198,6 +265,14 @@ class SklearnSGDModel(Model):
             Number of CPUs to use when to compute one-versus-all.
             Only used for multi-class classifiers.
             `None` means 1, while -1 means all available CPUs.
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both these conditions are met:
+                - `model` is an un-initialized instance;
+                  otherwise the dtype of existing coefficients is used.
+                - the scikit-learn version is >= 1.3;
+                  otherwise the dtype is forced to float64
 
         Notes
         -----
@@ -237,22 +312,26 @@ class SklearnSGDModel(Model):
             fit_intercept=fit_intercept,
             **kwargs,
         )
-        return cls(model)
+        return cls(model, dtype)
 
     def get_config(
         self,
     ) -> Dict[str, Any]:
         is_clf = isinstance(self._model, SGDClassifier)
         data_info = None  # type: Optional[Dict[str, Any]]
-        if self._initialized:
+        if hasattr(self._model, "coef_"):
             data_info = {
                 "features_shape": (self._model.coef_.shape[-1],),
                 "classes": self._model.classes_.tolist() if is_clf else None,
             }
+            dtype = self._model.coef_.dtype.name
+        else:
+            dtype = self._dtype
         return {
             "kind": "classifier" if is_clf else "regressor",
             "params": self._model.get_params(),
             "data_info": data_info,
+            "dtype": dtype,
         }
 
     @classmethod
@@ -261,13 +340,14 @@ class SklearnSGDModel(Model):
         config: Dict[str, Any],
     ) -> Self:
         """Instantiate a SklearnSGDModel from a configuration dict."""
-        for key in ("kind", "params"):
+        for key in ("kind", "params", "dtype"):
             if key not in config:
                 raise KeyError(f"Missing key '{key}' in the config dict.")
         if config["kind"] == "classifier":
-            model = cls(SGDClassifier(**config["params"]))
+            skmod = SGDClassifier(**config["params"])
         else:
-            model = cls(SGDRegressor(**config["params"]))
+            skmod = SGDRegressor(**config["params"])
+        model = cls(skmod, dtype=config["dtype"])
         if config.get("data_info"):
             model.initialize(config["data_info"])
         return model
@@ -306,8 +386,7 @@ class SklearnSGDModel(Model):
         x_data, y_data, s_wght = self._unpack_batch(batch)
         # Iteratively compute sample-wise gradients.
         grad = [
-            self._compute_sample_gradient(x, y)  # type: ignore
-            for x, y in zip(x_data, y_data)  # type: ignore
+            self._compute_sample_gradient(x, y) for x, y in zip(x_data, y_data)
         ]
         # Optionally clip sample-wise gradients based on their L2 norm.
         if max_norm:
@@ -317,27 +396,45 @@ class SklearnSGDModel(Model):
                     arr *= min(max_norm / norm, 1)
         # Optionally re-weight gradients based on sample weights.
         if s_wght is not None:
-            grad = [g * w for g, w in zip(grad, s_wght)]  # type: ignore
+            grad = [g * w for g, w in zip(grad, s_wght)]
         # Batch-average the gradients and return them.
         return sum(grad) / len(grad)  # type: ignore
 
     def _unpack_batch(
         self,
         batch: Batch,
-    ) -> Tuple[ArrayLike, ArrayLike, Optional[ArrayLike]]:
+    ) -> Tuple[DataArray, DataArray, Optional[DataArray]]:
         """Verify and unpack an input batch into (x, y, [w]).
 
         Note: this method does not verify arrays' dimensionality or
         shape coherence; the wrapped sklearn objects already do so.
         """
         x_data, y_data, s_wght = batch
-        invalid = (y_data is None) or isinstance(y_data, list)
-        if invalid or isinstance(x_data, list):
+        if (
+            (y_data is None)
+            or isinstance(y_data, list)
+            or isinstance(x_data, list)
+        ):
             raise TypeError(
                 "'SklearnSGDModel' requires (array, array, [array|None]) "
                 "data batches."
             )
-        return x_data, y_data, s_wght  # type: ignore
+        x_data = self._validate_and_cast_array(x_data)
+        y_data = self._validate_and_cast_array(y_data)
+        if s_wght is not None:
+            s_wght = self._validate_and_cast_array(s_wght)
+        return x_data, y_data, s_wght
+
+    def _validate_and_cast_array(
+        self,
+        array: ArrayLike,
+    ) -> DataArray:
+        """Type-check and type-cast an input data array."""
+        if not isinstance(array, typing.get_args(DataArray)):
+            raise TypeError(
+                f"Invalid data type for 'SklearnSGDModel': '{type(array)}'."
+            )
+        return array.astype(self._dtype, copy=False)  # type: ignore
 
     def _compute_sample_gradient(
         self,
diff --git a/declearn/model/tensorflow/utils/_loss.py b/declearn/model/tensorflow/utils/_loss.py
index c6b880c8e8aaa6dc90cb197e18bf7677f32285fc..5a3ffdaae06ee144f51777747f1d2bc797dd3f04 100644
--- a/declearn/model/tensorflow/utils/_loss.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -85,22 +85,9 @@ def build_keras_loss(
     # Case when 'loss' is already a Loss object.
     if isinstance(loss, tf.keras.losses.Loss):
         loss.reduction = reduction
-    # Case when 'loss' is a string.
+    # Case when 'loss' is a string: deserialize and/or wrap into a Loss object.
     elif isinstance(loss, str):
-        cls = tf.keras.losses.deserialize(loss)
-        # Case when the string was deserialized into a function.
-        if inspect.isfunction(cls):
-            # Try altering the string to gather its object counterpart.
-            loss = "".join(word.capitalize() for word in loss.split("_"))
-            try:
-                loss = tf.keras.losses.deserialize(loss)
-                loss.reduction = reduction
-            # If this failed, try wrapping the function using LossFunction.
-            except ValueError:
-                loss = LossFunction(cls)
-        # Case when the string was deserialized into a class.
-        else:
-            loss = cls(reduction=reduction)
+        loss = get_keras_loss_from_string(name=loss, reduction=reduction)
     # Case when 'loss' is a function: wrap it up using LossFunction.
     elif inspect.isfunction(loss):
         loss = LossFunction(loss, reduction=reduction)
@@ -111,3 +98,32 @@ def build_keras_loss(
         )
     # Otherwise, properly configure the reduction scheme and return.
     return loss
+
+
+def get_keras_loss_from_string(
+    name: str,
+    reduction: str,
+) -> tf.keras.losses.Loss:
+    """Instantiate a keras Loss object from a registered string identifier.
+
+    - If `name` matches a Loss registration name, return an instance.
+    - If it matches a loss function registration name, return either
+      an instance from its name-matching Loss subclass, or a custom
+      Loss subclass instance wrapping the function.
+    - If it does not match anything, raise a ValueError.
+    """
+    loss = tf.keras.losses.deserialize(name)
+    if isinstance(loss, tf.keras.losses.Loss):
+        loss.reduction = reduction
+        return loss
+    if inspect.isfunction(loss):
+        try:
+            name = "".join(word.capitalize() for word in name.split("_"))
+            return get_keras_loss_from_string(name, reduction)
+        except ValueError:
+            return LossFunction(
+                loss, reduction=reduction, name=getattr(loss, "__name__", None)
+            )
+    raise ValueError(
+        f"Name '{loss}' cannot be deserialized into a keras loss."
+    )
diff --git a/test/functional/test_main.py b/test/functional/test_main.py
index 1f1a0bfe67ab2ac0aea456c62c744f12000a2e97..986140da9b529578cd124daf1a2ce9f5394939b5 100644
--- a/test/functional/test_main.py
+++ b/test/functional/test_main.py
@@ -89,7 +89,8 @@ class DeclearnTestCase:
         """Return a Model suitable for the learning task and framework."""
         if self.framework.lower() == "sksgd":
             return SklearnSGDModel.from_parameters(
-                kind=("regressor" if self.kind == "Reg" else "classifier")
+                kind=("regressor" if self.kind == "Reg" else "classifier"),
+                dtype="float32",
             )
         if self.framework.lower() == "tflow":
             return self._build_tflow_model()
@@ -274,7 +275,7 @@ def run_test_case(
     )
 
 
-@pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"])
+@pytest.mark.parametrize("strategy", ["FedAvg", "Scaffold"])
 @pytest.mark.parametrize("framework", FRAMEWORKS)
 @pytest.mark.parametrize("kind", ["Reg", "Bin", "Clf"])
 @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
@@ -293,7 +294,7 @@ def test_declearn(
     Note: If websockets is unavailable, use gRPC (warn) or fail.
     """
     if not fulltest:
-        if (kind != "Reg") or (strategy == "FedAvgM"):
+        if (kind != "Reg") or (strategy == "FedAvg"):
             pytest.skip("skip scenario (no --fulltest option)")
     protocol = "websockets"  # type: Literal["grpc", "websockets"]
     if "websockets" not in list_available_protocols():
diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py
index 8ea5504a269c4b8446d1493bd7a19679cda1f18a..2a84a4b19e7093440d747063d6caa646d6489161 100644
--- a/test/model/test_sksgd.py
+++ b/test/model/test_sksgd.py
@@ -70,7 +70,7 @@ class SklearnSGDTestCase(ModelTestCase):
         assert isinstance(tensor, (np.ndarray, csr_matrix))
         if isinstance(tensor, csr_matrix):
             tensor = tensor.toarray()
-        return tensor  # type: ignore
+        return tensor
 
     @property
     def dataset(
@@ -78,15 +78,15 @@ class SklearnSGDTestCase(ModelTestCase):
     ) -> List[Batch]:
         """Suited toy binary-classification dataset."""
         rng = np.random.default_rng(seed=0)
-        inputs = rng.normal(size=(2, 32, 8))
+        inputs = rng.normal(size=(2, 32, 8)).astype("float32")
         if self.as_sparse:
             inputs = [csr_matrix(arr) for arr in inputs]  # type: ignore
         if isinstance(self.n_classes, int):
-            labels = rng.choice(self.n_classes, size=(2, 32)).astype(float)
+            labels = rng.choice(self.n_classes, size=(2, 32)).astype("float32")
         else:
-            labels = rng.normal(size=(2, 32))
+            labels = rng.normal(size=(2, 32)).astype("float32")
         if self.s_weights:
-            s_wght = np.exp(rng.normal(size=(2, 32)))
+            s_wght = np.exp(rng.normal(size=(2, 32)).astype("float32"))
             s_wght /= s_wght.sum(axis=1, keepdims=True) * 32
             batches = list(zip(inputs, labels, s_wght))
         else:
@@ -99,7 +99,7 @@ class SklearnSGDTestCase(ModelTestCase):
     ) -> SklearnSGDModel:
         """Suited toy binary-classification model."""
         skmod = (SGDClassifier if self.n_classes else SGDRegressor)()
-        model = SklearnSGDModel(skmod)
+        model = SklearnSGDModel(skmod, dtype="float32")
         data_info = {"features_shape": (8,)}  # type: Dict[str, Any]
         if self.n_classes:
             data_info["classes"] = np.arange(self.n_classes)
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index 0387075f66dab9bad3635d033e76775c56666624..168d1c0b9b2e1f3dc3c9d9fdd241282848215108 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -32,6 +32,7 @@ except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
+from declearn.model.tensorflow.utils import build_keras_loss
 from declearn.typing import Batch
 from declearn.utils import set_device_policy
 
@@ -221,3 +222,66 @@ class TestTensorflowModel(ModelTestSuite):
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
+
+
+class TestBuildKerasLoss:
+    """Unit tests for `build_keras_loss` util function."""
+
+    def test_build_keras_loss_from_string_class_name(self) -> None:
+        """Test `build_keras_loss` with a valid class name string input."""
+        loss = build_keras_loss(
+            "BinaryCrossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_function_name(self) -> None:
+        """Test `build_keras_loss` with a valid function name string input."""
+        loss = build_keras_loss(
+            "binary_crossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
+        """Test `build_keras_loss` with a valid function name string input."""
+        loss = build_keras_loss("mse", tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.mse
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_loss_instance(self) -> None:
+        """Test `build_keras_loss` with a valid keras Loss input."""
+        # Set up a BinaryCrossentropy loss instance.
+        loss = tf.keras.losses.BinaryCrossentropy(
+            reduction=tf.keras.losses.Reduction.SUM
+        )
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+        # Pass it through the util and verify that reduction changes.
+        loss = build_keras_loss(loss, tf.keras.losses.Reduction.NONE)
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.NONE
+
+    def test_build_keras_loss_from_loss_function(self) -> None:
+        """Test `build_keras_loss` with a valid keras loss function input."""
+        loss = build_keras_loss(
+            tf.keras.losses.binary_crossentropy, tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.binary_crossentropy
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_custom_function(self) -> None:
+        """Test `build_keras_loss` with a valid custom loss function input."""
+
+        def loss_fn(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+            """Custom loss function."""
+            return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32))
+
+        loss = build_keras_loss(loss_fn, tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is loss_fn
+        assert loss.reduction == tf.keras.losses.Reduction.SUM