From 5517116995b988bfdf5565c809f94897098573f2 Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Mon, 17 Apr 2023 17:00:36 +0200
Subject: [PATCH] Add a seed policy and wire existing seeding tools.

Can be heavily simplified if implementation of model seeding
does not require anything fancier than a seed integer.
---
 declearn/dataset/_inmemory.py            |  4 +-
 declearn/dataset/_split_data.py          |  5 +-
 declearn/dataset/utils/_split_classif.py |  2 +
 declearn/optimizer/modules/_noise.py     |  6 +-
 declearn/utils/__init__.py               |  6 ++
 declearn/utils/_seed_policy.py           | 76 ++++++++++++++++++++++++
 6 files changed, 94 insertions(+), 5 deletions(-)
 create mode 100644 declearn/utils/_seed_policy.py

diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py
index 09df296a..825a3b7a 100644
--- a/declearn/dataset/_inmemory.py
+++ b/declearn/dataset/_inmemory.py
@@ -31,7 +31,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 from declearn.dataset._base import Dataset, DataSpecs
 from declearn.dataset.utils import load_data_array, save_data_array
 from declearn.typing import Batch
-from declearn.utils import json_dump, json_load, register_type
+from declearn.utils import get_seed_policy, json_dump, json_load, register_type
 
 __all__ = [
     "InMemoryDataset",
@@ -165,7 +165,7 @@ class InMemoryDataset(Dataset):
         self.expose_classes = expose_classes
         self.expose_data_type = expose_data_type
         # Assign a random number generator.
-        self.seed = seed
+        self.seed = seed if seed else get_seed_policy().seed
         self._rng = np.random.default_rng(seed)
 
     @property
diff --git a/declearn/dataset/_split_data.py b/declearn/dataset/_split_data.py
index 1729c2dd..d91bad21 100644
--- a/declearn/dataset/_split_data.py
+++ b/declearn/dataset/_split_data.py
@@ -45,6 +45,7 @@ from declearn.dataset.utils import (
     save_data_array,
     split_multi_classif_dataset,
 )
+from declearn.utils import get_seed_policy
 
 
 __all__ = [
@@ -183,7 +184,8 @@ def split_data(
         Train/validation split in each client dataset, must be in the
         ]0,1] range.
     seed: int or None, default=None
-        Optional seed to the RNG used for all sampling operations.
+        Optional seed to the RNG used for all sampling operations. If None
+        default to global policy.
     """
     # pylint: disable=too-many-arguments,too-many-locals
     # Select output folder.
@@ -196,6 +198,7 @@ def split_data(
     print(
         f"Splitting data into {n_shards} shards using the '{scheme}' scheme."
     )
+    seed = seed if seed else get_seed_policy().seed
     split = split_multi_classif_dataset(
         dataset=(inputs, labels),
         n_shards=n_shards,
diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py
index 0e3dbc98..ceb56c5a 100644
--- a/declearn/dataset/utils/_split_classif.py
+++ b/declearn/dataset/utils/_split_classif.py
@@ -22,6 +22,7 @@ from typing import List, Literal, Optional, Tuple, Type, Union
 import numpy as np
 from scipy.sparse import csr_matrix, spmatrix  # type: ignore
 
+from declearn.utils import get_seed_policy
 
 __all__ = [
     "split_multi_classif_dataset",
@@ -85,6 +86,7 @@ def split_multi_classif_dataset(
     else:
         raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
     # Set up the RNG and unpack the dataset.
+    seed = seed if seed else get_seed_policy().seed
     rng = np.random.default_rng(seed)
     inputs, target = dataset
     # Optionally handle sparse matrix inputs.
diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py
index 2409e76e..cc28fc4b 100644
--- a/declearn/optimizer/modules/_noise.py
+++ b/declearn/optimizer/modules/_noise.py
@@ -28,6 +28,7 @@ import scipy.stats  # type: ignore
 from declearn.model.api import Vector
 from declearn.model.sklearn import NumpyVector
 from declearn.optimizer.modules._api import OptiModule
+from declearn.utils import get_seed_policy
 
 __all__ = [
     "GaussianNoiseModule",
@@ -60,11 +61,12 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
             is significantly slower.
         seed: int or None, default=None
             Seed used for initiliazing the non-secure random number generator.
-            If `safe_mode=True`, seed is ignored.
+            If `safe_mode=True`, seed is ignored. If None and safe mode is off,
+            defaults to the global seed policy.
         """
         rng = SystemRandom if safe_mode else np.random.default_rng
         self._rng = rng(seed)
-        self.seed = seed
+        self.seed = seed if seed else get_seed_policy().seed
 
     @property
     def safe_mode(self) -> bool:
diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py
index 451517d3..a33cf297 100644
--- a/declearn/utils/__init__.py
+++ b/declearn/utils/__init__.py
@@ -140,3 +140,9 @@ from ._serialize import (
     serialize_object,
 )
 from ._toml_config import TomlConfig
+
+from ._seed_policy import (
+    SeedPolicy,
+    get_seed_policy,
+    set_seed_policy,
+)
\ No newline at end of file
diff --git a/declearn/utils/_seed_policy.py b/declearn/utils/_seed_policy.py
new file mode 100644
index 00000000..d87f33f1
--- /dev/null
+++ b/declearn/utils/_seed_policy.py
@@ -0,0 +1,76 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to define a seed policy.
+
+This private submodule defines:
+
+- A dataclass defining a standard to hold a global seed.
+- A private global variable holding the current package-wise device policy.
+- A public pair of functions acting as a getter and a setter for that variable.
+"""
+
+import dataclasses
+from typing import Optional
+
+__all__ = [
+    "SeedPolicy",
+    "get_seed_policy",
+    "set_seed_policy",
+]
+
+
+@dataclasses.dataclass
+class SeedPolicy:
+    """Dataclass to store the global seed.
+
+    Attributes
+    ----------
+    seed: int or None
+        Optional global seed. If None, keep None as seed.
+    """
+
+    seed: Optional[int] = None
+
+
+SEED_POLICY = SeedPolicy()
+
+
+def get_seed_policy() -> SeedPolicy:
+    """Return a copy of the current global seed.
+
+    This method is meant to be used:
+
+    - By end-users that wish to check the current global seed.
+    - By the backend code of objects requiring seeding.
+
+    To update the current policy, use `declearn.utils.set_seed_policy`.
+    """
+    return SeedPolicy(**dataclasses.asdict(SEED_POLICY))
+
+
+def set_seed_policy(
+    seed: Optional[int] = None,
+) -> None:
+    """Update the current global device policy.
+
+    To access the current policy, use `declearn.utils.set_device_policy`.
+
+    """
+    # Using a global statement to have a proper setter to a private variable.
+    global SEED_POLICY  # pylint: disable=global-statement
+    SEED_POLICY = SeedPolicy(seed)
-- 
GitLab