diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 7825fa6555fdc4fffb199be47a7c438183f4d7a3..22d72c08f3edce7879722aed5cf052e4d5171d31 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -25,9 +25,9 @@ from typing import (
 )
 
 import numpy as np
-from numpy.typing import ArrayLike
 import pandas as pd
 import sklearn  # type: ignore
+from numpy.typing import ArrayLike
 from scipy.sparse import spmatrix  # type: ignore
 from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -65,6 +65,51 @@ REG_LOSSES = (
 DataArray = Union[np.ndarray, spmatrix, pd.DataFrame, pd.Series]
 
 
+def select_sgd_model_dtype(
+    model: Union[SGDClassifier, SGDRegressor],
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Return the current or future dtype of a scikit-learn SGD model.
+
+    Emit a `RuntimeWarning` if the model has already been initialized
+    with a dtype that does not match the input one, and/or if the
+    scikit-learn version makes it impossible to enforce a non-default
+    dtype.
+    """
+    # Identify the user-defined dtype's name.
+    dtype = _get_dtype_name(dtype)
+    # Warn if the dtype is already defined and does not match the argument.
+    if hasattr(model, "coef_") and (model.coef_.dtype.name != dtype):
+        warnings.warn(
+            f"Cannot enforce dtype '{dtype}' for pre-initialized "
+            f"scikit-learn SGD model (dtype '{model.coef_.dtype.name}').",
+            RuntimeWarning,
+        )
+        return model.coef_.dtype.name
+    # When using scikit-learn <= 1.3, warn about un-settable dtype.
+    if int(sklearn.__version__.split(".", 2)[1]) < 3 and (dtype != "float64"):
+        warnings.warn(
+            "Using scikit-learn <1.3; hence the 'float64' dtype will"
+            f"forcibly be used rather than user-input '{dtype}'.",
+            RuntimeWarning,
+        )
+        return "float64"
+    return dtype
+
+
+def _get_dtype_name(
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Gather the name of an input numpy dtype."""
+    if isinstance(dtype, np.dtype):
+        return dtype.name
+    if isinstance(dtype, type) and issubclass(dtype, np.number):
+        return dtype.__name__
+    if isinstance(dtype, str):
+        return dtype
+    raise TypeError("'dtype' should be a str, np.dtype or np.number type.")
+
+
 @register_type(name="SklearnSGDModel", group="Model")
 class SklearnSGDModel(Model):
     """Model wrapper for Scikit-Learn SGDClassifier and SGDRegressor.
@@ -120,26 +165,7 @@ class SklearnSGDModel(Model):
             average=False,
         )
         super().__init__(model)
-        if hasattr(model, "coef_"):
-            self._dtype = self._model.coef_.dtype.name
-        elif isinstance(dtype, np.dtype):
-            self._dtype = dtype.name
-        elif isinstance(dtype, type) and issubclass(dtype, np.number):
-            self._dtype = dtype.__name__
-        elif isinstance(dtype, str):
-            self._dtype = dtype
-        else:
-            raise TypeError(
-                "'dtype' should be a str, np.dtype or np.number type."
-            )
-        if int(sklearn.__version__.split(".", 2)[1]) < 3:
-            if self._dtype != "float64":
-                warnings.warn(
-                    "Using scikit-learn <1.3; hence the float64 dtype will"
-                    f"forcibly be used rather than user-input {self._dtype}.",
-                    RuntimeWarning,
-                )
-                self._dtype = "float64"
+        self._dtype = select_sgd_model_dtype(model, dtype)
         self._predict = (
             self._model.decision_function
             if isinstance(model, SGDClassifier)