From 5517116995b988bfdf5565c809f94897098573f2 Mon Sep 17 00:00:00 2001 From: BIGAUD Nathan <nathan.bigaud@inria.fr> Date: Mon, 17 Apr 2023 17:00:36 +0200 Subject: [PATCH] Add a seed policy and wire existing seeding tools. Can be heavily simplified if implementation of model seeding does not require anything fancier than a seed integer. --- declearn/dataset/_inmemory.py | 4 +- declearn/dataset/_split_data.py | 5 +- declearn/dataset/utils/_split_classif.py | 2 + declearn/optimizer/modules/_noise.py | 6 +- declearn/utils/__init__.py | 6 ++ declearn/utils/_seed_policy.py | 76 ++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 5 deletions(-) create mode 100644 declearn/utils/_seed_policy.py diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 09df296a..825a3b7a 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -31,7 +31,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.dataset._base import Dataset, DataSpecs from declearn.dataset.utils import load_data_array, save_data_array from declearn.typing import Batch -from declearn.utils import json_dump, json_load, register_type +from declearn.utils import get_seed_policy, json_dump, json_load, register_type __all__ = [ "InMemoryDataset", @@ -165,7 +165,7 @@ class InMemoryDataset(Dataset): self.expose_classes = expose_classes self.expose_data_type = expose_data_type # Assign a random number generator. - self.seed = seed + self.seed = seed if seed else get_seed_policy().seed self._rng = np.random.default_rng(seed) @property diff --git a/declearn/dataset/_split_data.py b/declearn/dataset/_split_data.py index 1729c2dd..d91bad21 100644 --- a/declearn/dataset/_split_data.py +++ b/declearn/dataset/_split_data.py @@ -45,6 +45,7 @@ from declearn.dataset.utils import ( save_data_array, split_multi_classif_dataset, ) +from declearn.utils import get_seed_policy __all__ = [ @@ -183,7 +184,8 @@ def split_data( Train/validation split in each client dataset, must be in the ]0,1] range. seed: int or None, default=None - Optional seed to the RNG used for all sampling operations. + Optional seed to the RNG used for all sampling operations. If None + default to global policy. """ # pylint: disable=too-many-arguments,too-many-locals # Select output folder. @@ -196,6 +198,7 @@ def split_data( print( f"Splitting data into {n_shards} shards using the '{scheme}' scheme." ) + seed = seed if seed else get_seed_policy().seed split = split_multi_classif_dataset( dataset=(inputs, labels), n_shards=n_shards, diff --git a/declearn/dataset/utils/_split_classif.py b/declearn/dataset/utils/_split_classif.py index 0e3dbc98..ceb56c5a 100644 --- a/declearn/dataset/utils/_split_classif.py +++ b/declearn/dataset/utils/_split_classif.py @@ -22,6 +22,7 @@ from typing import List, Literal, Optional, Tuple, Type, Union import numpy as np from scipy.sparse import csr_matrix, spmatrix # type: ignore +from declearn.utils import get_seed_policy __all__ = [ "split_multi_classif_dataset", @@ -85,6 +86,7 @@ def split_multi_classif_dataset( else: raise ValueError(f"Invalid 'scheme' value: '{scheme}'.") # Set up the RNG and unpack the dataset. + seed = seed if seed else get_seed_policy().seed rng = np.random.default_rng(seed) inputs, target = dataset # Optionally handle sparse matrix inputs. diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py index 2409e76e..cc28fc4b 100644 --- a/declearn/optimizer/modules/_noise.py +++ b/declearn/optimizer/modules/_noise.py @@ -28,6 +28,7 @@ import scipy.stats # type: ignore from declearn.model.api import Vector from declearn.model.sklearn import NumpyVector from declearn.optimizer.modules._api import OptiModule +from declearn.utils import get_seed_policy __all__ = [ "GaussianNoiseModule", @@ -60,11 +61,12 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): is significantly slower. seed: int or None, default=None Seed used for initiliazing the non-secure random number generator. - If `safe_mode=True`, seed is ignored. + If `safe_mode=True`, seed is ignored. If None and safe mode is off, + defaults to the global seed policy. """ rng = SystemRandom if safe_mode else np.random.default_rng self._rng = rng(seed) - self.seed = seed + self.seed = seed if seed else get_seed_policy().seed @property def safe_mode(self) -> bool: diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index 451517d3..a33cf297 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -140,3 +140,9 @@ from ._serialize import ( serialize_object, ) from ._toml_config import TomlConfig + +from ._seed_policy import ( + SeedPolicy, + get_seed_policy, + set_seed_policy, +) \ No newline at end of file diff --git a/declearn/utils/_seed_policy.py b/declearn/utils/_seed_policy.py new file mode 100644 index 00000000..d87f33f1 --- /dev/null +++ b/declearn/utils/_seed_policy.py @@ -0,0 +1,76 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to define a seed policy. + +This private submodule defines: + +- A dataclass defining a standard to hold a global seed. +- A private global variable holding the current package-wise device policy. +- A public pair of functions acting as a getter and a setter for that variable. +""" + +import dataclasses +from typing import Optional + +__all__ = [ + "SeedPolicy", + "get_seed_policy", + "set_seed_policy", +] + + +@dataclasses.dataclass +class SeedPolicy: + """Dataclass to store the global seed. + + Attributes + ---------- + seed: int or None + Optional global seed. If None, keep None as seed. + """ + + seed: Optional[int] = None + + +SEED_POLICY = SeedPolicy() + + +def get_seed_policy() -> SeedPolicy: + """Return a copy of the current global seed. + + This method is meant to be used: + + - By end-users that wish to check the current global seed. + - By the backend code of objects requiring seeding. + + To update the current policy, use `declearn.utils.set_seed_policy`. + """ + return SeedPolicy(**dataclasses.asdict(SEED_POLICY)) + + +def set_seed_policy( + seed: Optional[int] = None, +) -> None: + """Update the current global device policy. + + To access the current policy, use `declearn.utils.set_device_policy`. + + """ + # Using a global statement to have a proper setter to a private variable. + global SEED_POLICY # pylint: disable=global-statement + SEED_POLICY = SeedPolicy(seed) -- GitLab