From 45d1f776484ab8dda9e5d8e3e655d8177a1f8d51 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 27 May 2024 11:06:46 +0200
Subject: [PATCH] Fix pylint-identified variable (non-)assignment issues.

---
 declearn/dataset/_split_data.py | 2 ++
 declearn/model/sklearn/_sgd.py  | 6 ++++++
 declearn/optimizer/_base.py     | 1 +
 test/model/test_haiku_model.py  | 4 ++++
 test/model/test_tflow_model.py  | 4 ++++
 test/model/test_torch_model.py  | 2 ++
 6 files changed, 19 insertions(+)

diff --git a/declearn/dataset/_split_data.py b/declearn/dataset/_split_data.py
index 1729c2dd..aa4829f7 100644
--- a/declearn/dataset/_split_data.py
+++ b/declearn/dataset/_split_data.py
@@ -130,6 +130,8 @@ def _extract_column_by_index(
         csc = inputs.tocsc()  # sparse matrix with efficient column slicing
         idx = [i for i in range(inputs.shape[1]) if i != target]
         inputs = type(inputs)(csc[:, idx])
+    else:  # pragma: no cover
+        raise TypeError("Invalid type for 'inputs'.")
     return inputs, labels
 
 
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 17263da5..c0c8aba0 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -302,6 +302,12 @@ class SklearnSGDModel(Model):
             if loss not in REG_LOSSES:
                 raise ValueError(f"Invalid loss '{loss}' for SGDRegressor.")
             sk_cls = SGDRegressor
+        # Invalid input case.
+        else:  # pragma: no cover
+            raise ValueError(
+                "Invalid value for SklearnSGDModel 'kind': must be one of "
+                f"{'classifier', 'regressor'}, received '{kind}'."
+            )
         # Instantiate the sklearn model, wrap it up and return.
         model = sk_cls(
             loss=loss,
diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py
index 4894e1e4..1b637693 100644
--- a/declearn/optimizer/_base.py
+++ b/declearn/optimizer/_base.py
@@ -295,6 +295,7 @@ class Optimizer:
             weights = model.get_weights(trainable=True)
         # Run input gradients and weights through plug-in regularizers.
         if self.regularizers:
+            # false-positive; pylint: disable=possibly-used-before-assignment
             for regularizer in self.regularizers:
                 gradients = regularizer.run(gradients, weights)
         # Run input gradients through plug-in modules.
diff --git a/test/model/test_haiku_model.py b/test/model/test_haiku_model.py
index 7e31f048..8d0d9aec 100644
--- a/test/model/test_haiku_model.py
+++ b/test/model/test_haiku_model.py
@@ -144,6 +144,8 @@ class HaikuTestCase(ModelTestCase):
             inputs = rng.choice(100, size=(2, 32, 128))
         elif self.kind == "CNN":
             inputs = rng.normal(size=(2, 32, 64, 64, 3)).astype("float32")
+        else:
+            raise ValueError("Invalid model 'kind'.")
         labels = rng.choice(2, size=(2, 32))
         # Convert that data to jax-numpy and return it.
         with warnings.catch_warnings():  # jax.jit(device=...) is deprecated
@@ -164,6 +166,8 @@ class HaikuTestCase(ModelTestCase):
         elif self.kind == "RNN":
             shape = [128]
             model_fn = rnn_fn
+        else:
+            raise ValueError("Invalid model 'kind'.")
         model = HaikuModel(model_fn, loss_fn)
         model.initialize(
             {
diff --git a/test/model/test_tflow_model.py b/test/model/test_tflow_model.py
index 65805b5f..d9d45896 100644
--- a/test/model/test_tflow_model.py
+++ b/test/model/test_tflow_model.py
@@ -97,6 +97,8 @@ class TensorflowTestCase(ModelTestCase):
             inputs = rng.uniform((2, 32, 128), 0, 100, tf.int32)
         elif self.kind == "CNN":
             inputs = rng.normal((2, 32, 64, 64, 3))
+        else:
+            raise ValueError("Invalid model 'kind'.")
         labels = rng.uniform((2, 32), 0, 2, tf.int32)
         dataset = tf.data.Dataset.from_tensor_slices((inputs, labels, None))
         return list(iter(dataset))
@@ -133,6 +135,8 @@ class TensorflowTestCase(ModelTestCase):
                 tf_keras.layers.Dense(1, activation="sigmoid"),
             ]
             shape = [None, 64, 64, 3]
+        else:
+            raise ValueError("Invalid model 'kind'.")
         tfmod = tf_keras.Sequential(stack)
         tfmod.build(shape)  # as model is built, no data_info is required
         return TensorflowModel(tfmod, loss="binary_crossentropy", metrics=None)
diff --git a/test/model/test_torch_model.py b/test/model/test_torch_model.py
index 725e81c8..e442e075 100644
--- a/test/model/test_torch_model.py
+++ b/test/model/test_torch_model.py
@@ -127,6 +127,8 @@ class TorchTestCase(ModelTestCase):
             inputs = torch.randint(0, 100, (2, 32, 128), generator=rng)
         elif self.kind == "CNN":
             inputs = torch.randn((2, 32, 3, 64, 64), generator=rng)
+        else:
+            raise ValueError("Invalid model 'kind'.")
         labels = torch.randint(0, 2, (2, 32, 1), generator=rng)
         labels = labels.type(torch.float)
         dataset = torch.utils.data.TensorDataset(inputs, labels)
-- 
GitLab