diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..db87cc3cb08a7307ef9974f677b544250200b22a --- /dev/null +++ b/AUTHORS @@ -0,0 +1,19 @@ +This file maintains the list of present and past declearn authors. +A secondary file listing punctual open-source contributors may complement it. + +Declearn 2.0 (new implementation) +- Paul Andrey (core developer) +- Nathan Bigaud (core developer) +- Nathalie Vauquier (supervision) + +Declearn 1.x (continuation of 1.0) +- Nathalie Vauquier +- Marc-André Sirgiel + +Declearn 1.0 +- Yannick Bouillard (core developer) +- Paul Andrey (review and minor contributions) + +Supervision (all versions) +- Marc Tommasi +- Aurélien Bellet diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 01b48f0be74d3b08d5334feebc07397124fc1342..52c1c3fe9a9f1d5ecd3e436200ed3923a51bdd29 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ - [Hands-on usage](#hands-on-usage) - [Overview of local differential-privacy capabilities](#local-differential-privacy) - [Developers](#developers) - +- [Copyright](#copyright) -------------------- ## Introduction @@ -178,7 +178,7 @@ netwk = declearn.communication.NetworkServerConfig( certificate="path/to/certificate.pem", private_key="path/to/private_key.pem" ) -optim = declearn.main.FLOptimConfig.from_params( +optim = declearn.main.config.FLOptimConfig.from_params( aggregator="averaging", client_opt=0.001, ) @@ -188,7 +188,7 @@ server = declearn.main.FederatedServer( config = declearn.main.config.FLRunConfig.from_params( rounds=10, register={"min_clients": 1, "max_clients": 3, "timeout": 180}, - training={"n_epochs": 5, "batch_size": 128, "drop_remainder": False}, + training={"n_epoch": 5, "batch_size": 128, "drop_remainder": False}, ) server.run(config) ``` @@ -199,6 +199,7 @@ server.run(config) import declearn netwk = declearn.communication.NetworkClientConfig( + protocol="grpc", server_uri="example.com:8888", name="client_name", certificate="path/to/client_cert.pem" @@ -862,3 +863,37 @@ mypy declearn Note that the test suite run with tox comprises the previous command. If mypy identifies errors, the test suite will fail - notably preventing acceptance of merge requests. + + +## Copyright + +Declearn is an open-source software developed by people from the +[Magnet](https://team.inria.fr/magnet/) team at [Inria](https://www.inria.fr/). + +### Authors + +Current core developers are listed under the `pyproject.toml` file. A more +detailed acknowledgement and history of authors and contributors to declearn +can be found in the `AUTHORS` file. + +### License + +Declearn distributed under the Apache-2.0 license. All code files should +therefore contain the following mention, which also applies to the present +README file: +``` +Copyright 2023 Inria (Institut National de la Recherche en Informatique +et Automatique) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +``` diff --git a/declearn/__init__.py b/declearn/__init__.py index 3368da6f1883c8133faab4fa7b45bbab8b97e117..075582a5e0253267a303617eb869ae4cc79aee55 100644 --- a/declearn/__init__.py +++ b/declearn/__init__.py @@ -1,20 +1,32 @@ # coding: utf-8 -"""Declearn - a python package for decentralized learning. +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -Declearn is a framework providing with tools to set up and -run Federated Learning processes. It is being developed by -the MAGNET team of INRIA Lille, with the aim of providing -users with a modular and extensible framework to implement -federated learning algorithms and apply them to real-world -(or simulated) data using any model-defining framework one -might want to use. +"""Declearn - a python package for private decentralized learning. -Declearn provides with abstractions that enable algorithms -to be written agnostic to the actual computation framework -as well as with workable interfaces that cover some of the -most popular frameworks, such as Scikit-Learn, TensorFlow -and PyTorch. +Declearn is a modular framework to set up and run federated learning +processes. It is being developed by the MAGNET team of INRIA Lille, +with the aim of providing users with a modular and extensible framework +to implement federated learning algorithms and apply them to real-world +(or simulated) data using any common machine learning framework. + +Declearn provides with abstractions that enable algorithms to be written +agnostic to the actual computation framework as well as with workable +interfaces that cover some of the most popular frameworks, such as +Scikit-Learn, TensorFlow and PyTorch. The package is organized into the following submodules: * aggregator: @@ -39,15 +51,17 @@ The package is organized into the following submodules: Shared utils used (extensively) across all of declearn. """ -from . import typing -from . import utils -from . import communication -from . import data_info -from . import dataset -from . import metrics -from . import model -from . import optimizer -from . import aggregator -from . import main +from . import ( + aggregator, + communication, + data_info, + dataset, + main, + metrics, + model, + optimizer, + typing, + utils, +) -__version__ = "2.0.0b3" +__version__ = "2.0.0" diff --git a/declearn/aggregator/__init__.py b/declearn/aggregator/__init__.py index 69ed06b14c68adba063cac6db00c3a6837bbd758..3e93974ced32583edc146fe894ca219605c1cadd 100644 --- a/declearn/aggregator/__init__.py +++ b/declearn/aggregator/__init__.py @@ -1,6 +1,31 @@ # coding: utf-8 -"""Framework-agnostic Vector aggregation API and tools.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model updates aggregating API and implementations. An Aggregator +is typically meant to be used on a round-wise basis by the orchestrating +server of a centralized federated learning process, to aggregate the +client-wise model updated into a Vector that may then be used as "gradients" +by the server's Optimizer to update the global model. + +This declearn submodule provides with: +* Aggregator : abstract class defining an API for Vector aggregation +* AveragingAggregator : average-based-aggregation Aggregator subclass +* GradientMaskedAveraging : gradient Masked Averaging Aggregator subclass +""" from ._api import Aggregator from ._base import AveragingAggregator diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index e2cd5df454e68db96e62613ddd299f9bd7d932f6..e404e88e9b51dfa47aa4d16f8d237774cf8f4c01 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -1,15 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model updates aggregation API.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict - +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import create_types_registry, register_type - __all__ = [ "Aggregator", ] @@ -59,13 +72,15 @@ class Aggregator(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, register: bool = True, + **kwargs: Any, ) -> None: """Automatically type-register Aggregator subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, cls.name, group="Aggregator") @@ -92,7 +107,6 @@ class Aggregator(metaclass=ABCMeta): Aggregated updates, as a Vector - treated as gradients by the server-side optimizer. """ - return NotImplemented def get_config( self, diff --git a/declearn/aggregator/_base.py b/declearn/aggregator/_base.py index e1b64f6a5da12569de81acb0b1dda858c19b69da..7972ffab20a1a7e5ebee079aca4cf44e55050bd9 100644 --- a/declearn/aggregator/_base.py +++ b/declearn/aggregator/_base.py @@ -1,13 +1,26 @@ # coding: utf-8 -"""FedAvg-like mean-aggregation class.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from typing import Any, Dict, Optional +"""FedAvg-like mean-aggregation class.""" +from typing import Any, ClassVar, Dict, Optional -from declearn.model.api import Vector from declearn.aggregator._api import Aggregator - +from declearn.model.api import Vector __all__ = [ "AveragingAggregator", @@ -24,7 +37,7 @@ class AveragingAggregator(Aggregator): that use simple weighting schemes. """ - name = "averaging" + name: ClassVar[str] = "averaging" def __init__( self, diff --git a/declearn/aggregator/_gma.py b/declearn/aggregator/_gma.py index 98363e1783e740ee230c4894ed4d03467948878a..fa72114e66b77f78c4b3f3e3544dad9195c71056 100644 --- a/declearn/aggregator/_gma.py +++ b/declearn/aggregator/_gma.py @@ -1,13 +1,26 @@ # coding: utf-8 -"""Gradient Masked Averaging aggregation class.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from typing import Any, Dict, Optional +"""Gradient Masked Averaging aggregation class.""" +from typing import Any, ClassVar, Dict, Optional -from declearn.model.api import Vector from declearn.aggregator._base import AveragingAggregator - +from declearn.model.api import Vector __all__ = [ "GradientMaskedAveraging", @@ -42,7 +55,7 @@ class GradientMaskedAveraging(AveragingAggregator): https://arxiv.org/abs/2201.11986 """ - name = "gradient-masked-averaging" + name: ClassVar[str] = "gradient-masked-averaging" def __init__( self, diff --git a/declearn/communication/__init__.py b/declearn/communication/__init__.py index 393a85e021dbd6ad8d0b75fc3958526298626d89..ce57866eee6a952d34d07a98aa7424e1029a2426 100644 --- a/declearn/communication/__init__.py +++ b/declearn/communication/__init__.py @@ -1,6 +1,24 @@ # coding: utf-8 -"""Submodule implementing client/server communications. +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Submodule implementing client/server communications. This is done by +defining server-side and client-side network communication endpoints for +federated learning processes, as well as suitable messages to be transmitted, +and the available communication protocols. This module contains the following core submodules: * api: @@ -9,7 +27,6 @@ This module contains the following core submodules: Message dataclasses defining information containers to be exchanged between communication endpoints. - It also exposes the following core utility functions: * build_client: Instantiate a NetworkClient, selecting its subclass based on protocol name. @@ -19,7 +36,6 @@ It also exposes the following core utility functions: List the protocol names for which both a NetworkClient and NetworkServer classes are registered (hence available to `build_client`/`build_server`). - Finally, it defines the following protocol-specific submodules, provided the associated third-party dependencies are available: * grpc: @@ -31,15 +47,14 @@ the associated third-party dependencies are available: """ # Messaging and Communications API and base tools: -from . import messaging -from . import api +from . import api, messaging from ._build import ( + _INSTALLABLE_BACKENDS, NetworkClientConfig, NetworkServerConfig, build_client, build_server, list_available_protocols, - _INSTALLABLE_BACKENDS, ) # Concrete implementations using various protocols: diff --git a/declearn/communication/_build.py b/declearn/communication/_build.py index da2b4ea9a9b82e344625f53a89d0feb7e1dc9d7c..62cdf6279ffbb34d359d7f2d16cf77f61e78925b 100644 --- a/declearn/communication/_build.py +++ b/declearn/communication/_build.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Communication endpoints generic instantiation utils.""" import logging diff --git a/declearn/communication/api/__init__.py b/declearn/communication/api/__init__.py index 3d9af128b7e2e306d28fc2815b9f12e929a74ff1..ad974025505b0e0e341b35504a725171874a71c2 100644 --- a/declearn/communication/api/__init__.py +++ b/declearn/communication/api/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base API to define client- and server-side communication endpoints. This module provides `NetworkClient` and `NetworkServer`, two abstract diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index e0383436fd6adaa8f2797797f4d54a237d93582d..a3e01d15aad417a4e94e76fa1cca74d2d1aee973 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -1,12 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Abstract class defining an API for client-side communication endpoints.""" import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Type, Union - +from typing import Any, ClassVar, Dict, Optional, Type, Union from declearn.communication.messaging import ( Empty, @@ -18,7 +32,6 @@ from declearn.communication.messaging import ( ) from declearn.utils import create_types_registry, get_logger, register_type - __all__ = [ "NetworkClient", ] @@ -58,10 +71,15 @@ class NetworkClient(metaclass=ABCMeta): probably be rejected by the server if the client has not registered. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented - def __init_subclass__(cls, register: bool = True) -> None: + def __init_subclass__( + cls, + register: bool = True, + **kwargs: Any, + ) -> None: """Automate the type-registration of NetworkClient subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, cls.protocol, group="NetworkClient") @@ -107,7 +125,6 @@ class NetworkClient(metaclass=ABCMeta): The return type is communication-protocol dependent. """ - return NotImplemented # similar to NetworkServer API; pylint: disable=duplicate-code @@ -118,7 +135,6 @@ class NetworkClient(metaclass=ABCMeta): Note: this method can be called safely even if the client is already running (simply having no effect). """ - return None @abstractmethod async def stop(self) -> None: @@ -127,7 +143,6 @@ class NetworkClient(metaclass=ABCMeta): Note: this method can be called safely even if the client is not running (simply having no effect). """ - return None async def __aenter__( self, @@ -200,7 +215,6 @@ class NetworkClient(metaclass=ABCMeta): to send a Message (of any kind) to the server and await the primary reply from the `MessagesHandler` used by the server. """ - return NotImplemented async def send_message( self, diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py index 1194ee303aed87ebc3c68740bee7a2c8bdfec0d8..9d73c472db43d9cc80f3a8313979c7e9d6a823c1 100644 --- a/declearn/communication/api/_server.py +++ b/declearn/communication/api/_server.py @@ -1,12 +1,27 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Abstract class defining an API for server-side communication endpoints.""" import asyncio import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Set, Type, Union, ClassVar from declearn.communication.api._service import MessagesHandler @@ -54,10 +69,15 @@ class NetworkServer(metaclass=ABCMeta): of the awaitable `wait_for_clients` method. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented - def __init_subclass__(cls, register: bool = True) -> None: + def __init_subclass__( + cls, + register: bool = True, + **kwargs: Any, + ) -> None: """Automate the type-registration of NetworkServer subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, cls.protocol, group="NetworkServer") @@ -108,7 +128,6 @@ class NetworkServer(metaclass=ABCMeta): @abstractmethod def uri(self) -> str: """URI on which this server is exposed, to be requested by clients.""" - return NotImplemented @property def client_names(self) -> Set[str]: @@ -142,21 +161,18 @@ class NetworkServer(metaclass=ABCMeta): password: Optional[str] = None, ) -> Any: """Set up and return a SSL context object suitable for this class.""" - return NotImplemented @abstractmethod async def start( self, ) -> None: """Initialize the server and start welcoming communications.""" - return None @abstractmethod async def stop( self, ) -> None: """Stop the server and purge information about clients.""" - return None async def __aenter__( self, diff --git a/declearn/communication/api/_service.py b/declearn/communication/api/_service.py index adbe9045599f4d130341888be823466018f49b86..d0449e6698e5aa0fb1aa3f1047f7456f7feedf53 100644 --- a/declearn/communication/api/_service.py +++ b/declearn/communication/api/_service.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Protocol-agnostic server-side network messages handler.""" import asyncio diff --git a/declearn/communication/grpc/__init__.py b/declearn/communication/grpc/__init__.py index 037ac9466359b63810263f8a2bc0b138310f5281..598822c0d5f3dce444bc9ebaaa83bd95f1dcf7a9 100644 --- a/declearn/communication/grpc/__init__.py +++ b/declearn/communication/grpc/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """gRPC implementation of network communication endpoints.""" from . import protobufs diff --git a/declearn/communication/grpc/_client.py b/declearn/communication/grpc/_client.py index 941d7f761715c5b33b3e08bf572a0b36eaf98686..31971ac1b5bbc48cdd1f529f1a684fa17dd3db30 100644 --- a/declearn/communication/grpc/_client.py +++ b/declearn/communication/grpc/_client.py @@ -1,19 +1,33 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Client-side communication endpoint implementation using gRPC""" import logging -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union import grpc # type: ignore from declearn.communication.api import NetworkClient -from declearn.communication.messaging import Message, parse_message_from_string from declearn.communication.grpc.protobufs import message_pb2 from declearn.communication.grpc.protobufs.message_pb2_grpc import ( MessageBoardStub, ) - +from declearn.communication.messaging import Message, parse_message_from_string __all__ = [ "GrpcClient", @@ -26,7 +40,7 @@ CHUNK_LENGTH = 2**22 - 50 # 2**22 - sys.getsizeof("") - 1 class GrpcClient(NetworkClient): """Client-side communication endpoint using gRPC.""" - protocol = "grpc" + protocol: ClassVar[str] = "grpc" def __init__( self, diff --git a/declearn/communication/grpc/_server.py b/declearn/communication/grpc/_server.py index 86762e2275e82cbbd69cbec601b3bfea5cbb2a47..831ae2d793790dbd52784a2d7dfd3e0152e10a3f 100644 --- a/declearn/communication/grpc/_server.py +++ b/declearn/communication/grpc/_server.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Server-side communication endpoint implementation using gRPC.""" import getpass diff --git a/declearn/communication/grpc/protobufs/__init__.py b/declearn/communication/grpc/protobufs/__init__.py index 5a3f063e82713b327c5617704b8bca15bb68cc05..a4085a746901b0906882744023bb4efbded7655a 100644 --- a/declearn/communication/grpc/protobufs/__init__.py +++ b/declearn/communication/grpc/protobufs/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Load the gRPC backend code auto-generated from "message.proto". Instructions to re-generate the code: diff --git a/declearn/communication/messaging/__init__.py b/declearn/communication/messaging/__init__.py index 9322338816df0fa97712b05f8cf118653a5ee5f4..464024fd6402be43853e98d3d1ce075455e0e87e 100644 --- a/declearn/communication/messaging/__init__.py +++ b/declearn/communication/messaging/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Submodule defining messaging containers and flags used in declearn.""" from . import flags diff --git a/declearn/communication/messaging/_messages.py b/declearn/communication/messaging/_messages.py index b29617865a925dd38443387ab629651fa0f2ad0a..46cd52d278e0838f3958786b1c4482adb1fb9f55 100644 --- a/declearn/communication/messaging/_messages.py +++ b/declearn/communication/messaging/_messages.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dataclasses defining messages used in declearn communications.""" import dataclasses diff --git a/declearn/communication/messaging/flags.py b/declearn/communication/messaging/flags.py index 4d7abd65a4c62e463a0d3e833eb9f926ec298a34..c3e0dc0e3048703657e1cb1ab5efa9fc09a6e26c 100644 --- a/declearn/communication/messaging/flags.py +++ b/declearn/communication/messaging/flags.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Communication flags used in declearn communication backends.""" diff --git a/declearn/communication/websockets/__init__.py b/declearn/communication/websockets/__init__.py index 934159ff3a67e3bf08a42a494dd860664aaab2a6..858f7d6a2cbdd1ca8998deea428a96c0357a54c0 100644 --- a/declearn/communication/websockets/__init__.py +++ b/declearn/communication/websockets/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """WebSockets implementation of network communication endpoints.""" from ._client import WebsocketsClient diff --git a/declearn/communication/websockets/_client.py b/declearn/communication/websockets/_client.py index 59627c50180345fab328c7187a0282624414d434..6eeca4fb91f2afe55ed9247a2c5821fc28db0a7a 100644 --- a/declearn/communication/websockets/_client.py +++ b/declearn/communication/websockets/_client.py @@ -1,11 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Client-side communication endpoint implementation using WebSockets.""" import asyncio import logging import ssl -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, ClassVar import websockets as ws from websockets.client import WebSocketClientProtocol @@ -25,7 +40,7 @@ CHUNK_LENGTH = 100000 class WebsocketsClient(NetworkClient): """Client-side communication endpoint using WebSockets.""" - protocol = "websockets" + protocol: ClassVar[str] = "websockets" def __init__( self, diff --git a/declearn/communication/websockets/_server.py b/declearn/communication/websockets/_server.py index 35dd5a26318dc59fdbae45acf844ce0f0a455978..89516d443a9dbef4e91bfdfd10128cdb2259d713 100644 --- a/declearn/communication/websockets/_server.py +++ b/declearn/communication/websockets/_server.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Server-side communication endpoint implementation using WebSockets.""" import logging diff --git a/declearn/communication/websockets/_tools.py b/declearn/communication/websockets/_tools.py index 67ce9604b5527307dce9038c4dfa1eed2ad1f257..d31850f04f6edc28ac4d6c280e47f1045aec3c99 100644 --- a/declearn/communication/websockets/_tools.py +++ b/declearn/communication/websockets/_tools.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared backend utils for Websockets communication endpoints.""" import sys diff --git a/declearn/data_info/__init__.py b/declearn/data_info/__init__.py index d937d011f0812ba74e041528f88a0278ef440d8b..bc3d471d67a39f9ebd81aaf3c30b9f7ba74ce73a 100644 --- a/declearn/data_info/__init__.py +++ b/declearn/data_info/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tools to write 'data_info' metadata fields specifications. The 'data_info' dictionaries are a discrete yet important component of @@ -13,8 +28,7 @@ writing specifications for expected 'data_info' fields, and automating their use to validate and combine individual 'data_info' dicts into an aggregated one. -DataInfoField API tools ------------------------ +DataInfoField API tools: * DataInfoField: Abstract class defining an API to write field-wise specifications. * register_data_info_field: @@ -24,8 +38,7 @@ DataInfoField API tools * get_data_info_fields_documentation: Gather documentation for all fields that have registered specs. -Field specifications --------------------- +Field specifications: * ClassesField: Specification for the 'classes' field. * InputShapeField: diff --git a/declearn/data_info/_base.py b/declearn/data_info/_base.py index 99bc1f680fb878ef68595db8447801e2824e9580..d8cdb322177e7c940def738c899107619fd5c018 100644 --- a/declearn/data_info/_base.py +++ b/declearn/data_info/_base.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tools to write 'data_info' metadata fields specifications. The 'data_info' dictionaries are a discrete yet important component of @@ -34,7 +49,7 @@ data_info fields, are implemented (although unexposed) here. import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, List, Optional, Set, Tuple, Type, ClassVar __all__ = [ @@ -75,9 +90,9 @@ class DataInfoField(metaclass=ABCMeta): is called, run `is_valid` on each and every input. """ - field: str - types: Tuple[Type, ...] - doc: str + field: ClassVar[str] = NotImplemented + types: ClassVar[Tuple[Type, ...]] = NotImplemented + doc: ClassVar[str] = NotImplemented @classmethod def is_valid( @@ -85,6 +100,7 @@ class DataInfoField(metaclass=ABCMeta): value: Any, ) -> bool: """Check that a given value may belong to this field.""" + # false-pos; pylint: disable=isinstance-second-argument-not-valid-type return isinstance(value, cls.types) @classmethod @@ -101,7 +117,6 @@ class DataInfoField(metaclass=ABCMeta): raise ValueError( f"Cannot combine '{cls.field}': invalid values encountered." ) - return NotImplemented DATA_INFO_FIELDS = {} # type: Dict[str, Type[DataInfoField]] diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py index 7d456646f11728b426e246a200b98d654fbb9152..6d2f325681d74c3a1e8cd2f4af03cb39c1b919d9 100644 --- a/declearn/data_info/_fields.py +++ b/declearn/data_info/_fields.py @@ -1,14 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """DataInfoField subclasses specifying common 'data_info' metadata fields.""" -from typing import Any, List, Optional, Set +from typing import Any, ClassVar, List, Optional, Set, Tuple, Type import numpy as np from declearn.data_info._base import DataInfoField, register_data_info_field - __all__ = [ "ClassesField", "InputShapeField", @@ -21,9 +35,9 @@ __all__ = [ class ClassesField(DataInfoField): """Specifications for 'classes' data_info field.""" - field = "classes" - types = (list, set, tuple, np.ndarray) - doc = "Set of classification targets, combined by union." + field: ClassVar[str] = "classes" + types: ClassVar[Tuple[Type, ...]] = (list, set, tuple, np.ndarray) + doc: ClassVar[str] = "Set of classification targets, combined by union." @classmethod def is_valid( @@ -47,9 +61,9 @@ class ClassesField(DataInfoField): class InputShapeField(DataInfoField): """Specifications for 'input_shape' data_info field.""" - field = "input_shape" - types = (tuple, list) - doc = "Input features' batched shape, checked to be equal." + field: ClassVar[str] = "input_shape" + types: ClassVar[Tuple[Type, ...]] = (tuple, list) + doc: ClassVar[str] = "Input features' batched shape, checked to be equal." @classmethod def is_valid( @@ -96,9 +110,9 @@ class InputShapeField(DataInfoField): class NbFeaturesField(DataInfoField): """Specifications for 'n_features' data_info field.""" - field = "n_features" - types = (int,) - doc = "Number of input features, checked to be equal." + field: ClassVar[str] = "n_features" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of input features, checked to be equal." @classmethod def is_valid( @@ -128,9 +142,9 @@ class NbFeaturesField(DataInfoField): class NbSamplesField(DataInfoField): """Specifications for 'n_samples' data_info field.""" - field = "n_samples" - types = (int,) - doc = "Number of data samples, combined by summation." + field: ClassVar[str] = "n_samples" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of data samples, combined by summation." @classmethod def is_valid( diff --git a/declearn/dataset/__init__.py b/declearn/dataset/__init__.py index 2de0afa1cee1e7bf0b8e0429137036b4b4a90504..0845515e73f5dde23651556af23b3fe8f568724e 100644 --- a/declearn/dataset/__init__.py +++ b/declearn/dataset/__init__.py @@ -1,6 +1,30 @@ # coding: utf-8 -"""Dataset-interface API and actual implementations module.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset-interface API and actual implementations module. A 'Dataset' +is an interface towards data that exposes methods to query batched data +samples and key metadata while remaining agnostic of the way the data +is actually being loaded (from a source file, a database, another API...). + +This declearn submodule provides with: +* Dataset : abstract class defining an API to access training or testing data +* InMemoryDataset : Dataset subclass serving numpy(-like) memory-loaded data +arrays +""" from ._base import Dataset, DataSpecs, load_dataset_from_json diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index 665ba2b99fb2787eab3c3e1c4b2c13d05c99d3d6..118571c80dc62771e098197c492f64be78d8da63 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -1,15 +1,29 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dataset abstraction API.""" from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Iterator, Optional, Set +from typing import Any, ClassVar, Iterator, Optional, Set from declearn.typing import Batch from declearn.utils import access_registered, create_types_registry, json_load - __all__ = [ "DataSpecs", "Dataset", @@ -41,7 +55,7 @@ class Dataset(metaclass=ABCMeta): straightforward to specify as part of FL algorithms. """ - _type_key: str = NotImplemented + _type_key: ClassVar[str] = NotImplemented @abstractmethod def save_to_json( @@ -56,7 +70,6 @@ class Dataset(metaclass=ABCMeta): Path to the main JSON file where to dump the dataset. Additional files may be created in the same folder. """ - return None @classmethod @abstractmethod @@ -65,14 +78,12 @@ class Dataset(metaclass=ABCMeta): path: str, ) -> "Dataset": """Instantiate a dataset based on local files.""" - return NotImplemented @abstractmethod def get_data_specs( self, ) -> DataSpecs: """Return a DataSpecs object describing this dataset.""" - return NotImplemented @abstractmethod def generate_batches( @@ -115,7 +126,6 @@ class Dataset(metaclass=ABCMeta): Optional weights associated with the samples, that are typically used to balance a model's loss or metrics. """ - return NotImplemented def load_dataset_from_json(path: str) -> Dataset: diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 97987cffd4c0b6953559f23edfa3d0f2b15f3fab..e48db486295b667d9c8dca4367b877969ab836c7 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -1,10 +1,25 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dataset implementation to serve scikit-learn compatible in-memory data.""" import functools import os -from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union import numpy as np import pandas as pd # type: ignore @@ -17,7 +32,6 @@ from declearn.dataset._sparse import sparse_from_file, sparse_to_file from declearn.typing import Batch from declearn.utils import json_dump, json_load, register_type - __all__ = [ "InMemoryDataset", ] @@ -56,7 +70,7 @@ class InMemoryDataset(Dataset): # attributes serve clarity; pylint: disable=too-many-instance-attributes - _type_key = "InMemoryDataset" + _type_key: ClassVar[str] = "InMemoryDataset" def __init__( self, diff --git a/declearn/dataset/_sparse.py b/declearn/dataset/_sparse.py index 57ae8e1d5708ccf3e7c9f2c6a554fc7425f82b27..cbf8cbf4ab251fdaf2b90b898acb5679b5ee0e72 100644 --- a/declearn/dataset/_sparse.py +++ b/declearn/dataset/_sparse.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Sparse matrix file dumping and loading utils, inspired by svmlight. The format used is mostly similar to the SVMlight one diff --git a/declearn/main/__init__.py b/declearn/main/__init__.py index 497e850e2a5e48f5df05ac519052188165fc8116..298900f961bd54418cdc5e081623da7f85771c26 100644 --- a/declearn/main/__init__.py +++ b/declearn/main/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Main classes implementing a Federated Learning process. This module mainly implements the following two classes: diff --git a/declearn/main/_client.py b/declearn/main/_client.py index f0707f1627d224d849095ef2dfa43eea4a51d2f9..cfd898fb784f9db7410fc3efd3ff126f2118843c 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Client-side main Federated Learning orchestrating class.""" import asyncio diff --git a/declearn/main/_server.py b/declearn/main/_server.py index a7cfc9a0ef55e097f13223a64eba9c28b2c0547a..ef29bbd43d16eae23409506b57b12b0af78673f5 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Server-side main Federated Learning orchestrating class.""" import asyncio diff --git a/declearn/main/config/__init__.py b/declearn/main/config/__init__.py index e96f1a5389605620081aea45f17a95a11dd48599..9114a6f627dfdc2c675467af5966f70bbe413fa0 100644 --- a/declearn/main/config/__init__.py +++ b/declearn/main/config/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tools to specify hyper-parameters of a Federated Learning process. This submodule exposes dataclasses that group, document and facilitate diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py index a2f429e30c4b5a694065db260b390d47a0a865eb..cc3cd8b0f7cc455967fec5f0f31a16da878bc8c7 100644 --- a/declearn/main/config/_dataclasses.py +++ b/declearn/main/config/_dataclasses.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dataclasses to wrap and parse some training-related hyperparameters.""" import dataclasses diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py index a63f7be242c30cf7b2ec81216793db0cec284130..eb88e3224e58da771d58d52e1d9644830df56ab1 100644 --- a/declearn/main/config/_run_config.py +++ b/declearn/main/config/_run_config.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """TOML-parsable container for Federated Learning "run" configurations.""" import dataclasses @@ -100,7 +115,7 @@ class FLRunConfig(TomlConfig): if inputs is None: return RegisterConfig() if isinstance(inputs, int): - return RegisterConfig(min_clients=1) + return RegisterConfig(min_clients=inputs) return cls.default_parser(field, inputs) @classmethod diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py index 7284eb5213577251ee95aa6152020b06da969535..cfae32be56c782a1183f89ee243757593558664f 100644 --- a/declearn/main/config/_strategy.py +++ b/declearn/main/config/_strategy.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """TOML-parsable container for a Federated Learning optimization strategy.""" import dataclasses diff --git a/declearn/main/privacy/__init__.py b/declearn/main/privacy/__init__.py index 7f7c0f153bce50259683b7e5ad2c5852ac077891..99b325d1636c983e5885b9701dd293c6d040c3cb 100644 --- a/declearn/main/privacy/__init__.py +++ b/declearn/main/privacy/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Submodule implementing Differential-Privacy-oriented tools.""" from ._dp_trainer import DPTrainingManager diff --git a/declearn/main/privacy/_dp_trainer.py b/declearn/main/privacy/_dp_trainer.py index 6a3d198e1ab195ca7bc57c77642b7c2ee4ce773e..76b9f92ab8afd9d434d6a95f66adccd86dd078b4 100644 --- a/declearn/main/privacy/_dp_trainer.py +++ b/declearn/main/privacy/_dp_trainer.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """TrainingManager subclass implementing Differential Privacy mechanisms.""" import logging diff --git a/declearn/main/utils/__init__.py b/declearn/main/utils/__init__.py index 057aa9c9eda8be3678ba830355d4c502cdbbff36..2fe4cb3c3fe1a5c88e4010987e0fd3ee03a619a7 100644 --- a/declearn/main/utils/__init__.py +++ b/declearn/main/utils/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Utils for the main federated learning traning and evaluation processes.""" from ._checkpoint import Checkpointer diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py index ecc95933c384dd780d26f386b2f1f8f035a355d1..aba199735e3575ee55f2121f52287b9c7eb10616 100644 --- a/declearn/main/utils/_checkpoint.py +++ b/declearn/main/utils/_checkpoint.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model and metrics checkpointing util.""" import json diff --git a/declearn/main/utils/_constraints.py b/declearn/main/utils/_constraints.py index a5e5a36a3b5a3d63f8fd58e09edc57d30e5687ed..2204973854b6a76e1b8ebfde3d041299ff53593a 100644 --- a/declearn/main/utils/_constraints.py +++ b/declearn/main/utils/_constraints.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Minimal API to design and enforce computational effort constraints.""" import time diff --git a/declearn/main/utils/_data_info.py b/declearn/main/utils/_data_info.py index a4c94f308d1175aced79f86cac9c9d13a1892a29..b4061e86fc1fe3673ee5cd71d65978621d998218 100644 --- a/declearn/main/utils/_data_info.py +++ b/declearn/main/utils/_data_info.py @@ -1,8 +1,23 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dependency functions for a FL server to process 'data_info'.""" -from typing import Any, Dict, Set +from typing import Any, Dict, Set, NoReturn from declearn.data_info import aggregate_data_info @@ -94,10 +109,6 @@ def aggregate_clients_data_info( _raise_incompatible_fields(clients_data_info, exc) except Exception as exc: # re-raise; pylint: disable=broad-except _raise_aggregation_fails(clients_data_info, exc) - # Unreachable as the broad except raises. - raise NotImplementedError( - "Unreachable code reached in 'aggregate_clients_data_info'." - ) def _raise_on_missing_fields( @@ -135,7 +146,7 @@ def _raise_on_missing_fields( def _raise_aggregation_fails( clients_data_info: Dict[str, Dict[str, Any]], exception: Exception, -) -> None: +) -> NoReturn: """Raise information about aggregation failure for unexpected cause. Raise a RuntimeError containing client-wise messages and server error. @@ -185,7 +196,7 @@ def _raise_on_invalid_fields( def _raise_incompatible_fields( clients_data_info: Dict[str, Dict[str, Any]], exception: ValueError, -) -> None: +) -> NoReturn: """Raise information about incompatible-due data_info agg. failure. Raise a RuntimeError containing client-wise messages and server error. diff --git a/declearn/main/utils/_early_stop.py b/declearn/main/utils/_early_stop.py index 86c15346c1750f6a7d6df01b62f06fa68e8279ef..86a516e5a20a0df8614607f4ab4758273c582735 100644 --- a/declearn/main/utils/_early_stop.py +++ b/declearn/main/utils/_early_stop.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Simple implementation of metric-based early stopping.""" from typing import Optional diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index ea4b2106afaef3d5ce6ce42ec19bb54733b35ebe..39ab2c7e81cab671e883209174458c8764f59e01 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -1,9 +1,24 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Wrapper to run local training and evaluation rounds in a FL process.""" import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np @@ -92,11 +107,12 @@ class TrainingManager: ) -> Metric: """Return an ad-hoc Metric object to compute the model's loss.""" loss_fn = self.model.loss_function + # Write a custom, unregistered Metric subclass. class LossMetric(MeanMetric, register=False): """Ad hoc Metric wrapping a model's loss function.""" - name = "loss" + name: ClassVar[str] = "loss" def metric_func( self, y_true: np.ndarray, y_pred: np.ndarray diff --git a/declearn/metrics/__init__.py b/declearn/metrics/__init__.py index 14416d0072de80fa03839326786fb9e8391b4fbd..d4e9cf5ddbf1c071795964bdc20bee1ff18197c3 100644 --- a/declearn/metrics/__init__.py +++ b/declearn/metrics/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Iterative and federative evaluation metrics computation tools. This module provides with Metric, an abstract base class that defines diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index c408b9918dfcc4e4f16fe2138533309721e72554..89909cfff06a30b08677ad56812f87f3d4645d1a 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Iterative and federative evaluation metrics base class.""" from abc import ABCMeta, abstractmethod @@ -104,7 +119,7 @@ class Metric(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: ClassVar[str] + name: ClassVar[str] = NotImplemented def __init__( self, @@ -249,8 +264,10 @@ class Metric(metaclass=ABCMeta): def __init_subclass__( cls, register: bool = True, + **kwargs: Any, ) -> None: """Automatically type-register Metric subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, name=cls.name, group="Metric") diff --git a/declearn/metrics/_classif.py b/declearn/metrics/_classif.py index b6cb9dfb3dd47ada463a399eb1c3d8a552597683..30a0f3a348709a16d3018b252c2a7c056ebe225f 100644 --- a/declearn/metrics/_classif.py +++ b/declearn/metrics/_classif.py @@ -1,8 +1,23 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Iterative and federative classification evaluation metrics.""" -from typing import Any, Collection, Dict, Optional, Union +from typing import Any, ClassVar, Collection, Dict, Optional, Union import numpy as np import sklearn # type: ignore @@ -41,7 +56,7 @@ class BinaryAccuracyPrecisionRecall(Metric): Confusion matrix of predictions. Values: [[TN, FP], [FN, TP]] """ - name = "binary-classif" + name: ClassVar[str] = "binary-classif" def __init__( self, @@ -131,7 +146,7 @@ class MulticlassAccuracyPrecisionRecall(Metric): were predicted to belong to label j. """ - name = "multi-classif" + name: ClassVar[str] = "multi-classif" def __init__( self, diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 573cbed906be3ee1a09f528a59670163c971a5d4..7ac765a12c2309f18dfce7dcc9c1142d9310a619 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -1,9 +1,24 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Iterative and federative generic evaluation metrics.""" from abc import ABCMeta, abstractmethod -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Union import numpy as np @@ -110,7 +125,7 @@ class MeanAbsoluteError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mae" + name: ClassVar[str] = "mae" def metric_func( self, @@ -138,7 +153,7 @@ class MeanSquaredError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mse" + name: ClassVar[str] = "mse" def metric_func( self, diff --git a/declearn/metrics/_roc_auc.py b/declearn/metrics/_roc_auc.py index 7a3d2980c21d82a31d85784cc008ba3a1cd38eec..a8a97dcc0a277081ba9fe92f098ba1dca92a3f80 100644 --- a/declearn/metrics/_roc_auc.py +++ b/declearn/metrics/_roc_auc.py @@ -1,8 +1,23 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Iterative and federative ROC AUC evaluation metrics.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Union import numpy as np import sklearn # type: ignore @@ -10,7 +25,6 @@ import sklearn.metrics # type: ignore from declearn.metrics._api import Metric - __all__ = [ "BinaryRocAUC", ] @@ -42,7 +56,7 @@ class BinaryRocAUC(Metric): unless its """ - name = "binary-roc" + name: ClassVar[str] = "binary-roc" def __init__( self, @@ -159,7 +173,7 @@ class BinaryRocAUC(Metric): "fneg": (s_wght * (~tru & pos)).sum(axis=0), } # Aggregate these scores into the retained states. - thresh, states = combine_roc_states( + thresh, states = _combine_roc_states( thresh, states, self._states["thr"], # type: ignore @@ -188,14 +202,14 @@ class BinaryRocAUC(Metric): msg = "Input thresholds differ from bounded self ones." raise ValueError(msg) # Combine input states with self ones. - thresh, states = combine_roc_states( # type: ignore + thresh, states = _combine_roc_states( # type: ignore thr_own, self._states, thr_oth, states # type: ignore ) self._states = states self._states["thr"] = thresh -def combine_roc_states( +def _combine_roc_states( thresh_a: np.ndarray, states_a: Dict[str, np.ndarray], thresh_b: np.ndarray, @@ -229,13 +243,13 @@ def combine_roc_states( return thresh_a, states # Case when thresholds need alignment. thresh = np.union1d(thresh_a, thresh_b) - states_a = interpolate_roc_states(thresh, thresh_a, states_a) - states_b = interpolate_roc_states(thresh, thresh_b, states_b) + states_a = _interpolate_roc_states(thresh, thresh_a, states_a) + states_b = _interpolate_roc_states(thresh, thresh_b, states_b) states = {key: states_a[key] + states_b[key] for key in states_a} return thresh, states -def interpolate_roc_states( +def _interpolate_roc_states( thresh_r: np.ndarray, thresh_p: np.ndarray, states_p: Dict[str, np.ndarray], diff --git a/declearn/metrics/_wrapper.py b/declearn/metrics/_wrapper.py index 6b825e6b840f482b99397ad7736719c422d67c97..e16fee7ec5c266bd38a047e63bbe7335aa2c32ce 100644 --- a/declearn/metrics/_wrapper.py +++ b/declearn/metrics/_wrapper.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Wrapper for an ensemble of Metric objects.""" from typing import Any, Dict, List, Optional, Tuple, Union diff --git a/declearn/model/__init__.py b/declearn/model/__init__.py index b6bf690931ac6108fe57e660f43059fb2ff8bb33..08d9cdb897d87c7572a1eb21499697ff53d4b0d1 100644 --- a/declearn/model/__init__.py +++ b/declearn/model/__init__.py @@ -1,10 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model interfacing submodule, defining an API an derived applications. This declearn submodule provides with: * Model and Vector abstractions, used as an API to design FL algorithms -* Submodules implementing interfaces to various frameworks and models. +* Submodules implementing interfaces to curretnly supported frameworks +and models. """ from . import api diff --git a/declearn/model/api/__init__.py b/declearn/model/api/__init__.py index 35cecc33a8a5516d268078cc860524a42ab1d4ed..ab147758c5187ff8c62a507d3c63824424045646 100644 --- a/declearn/model/api/__init__.py +++ b/declearn/model/api/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model Vector abstractions submodule.""" from ._vector import Vector, register_vector_type diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py index a2238faea9bb7eb693667b993878f989e84def73..0e86778bb0cc0f91ee05d34149d2712735979b0d 100644 --- a/declearn/model/api/_model.py +++ b/declearn/model/api/_model.py @@ -1,10 +1,24 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model abstraction API.""" -import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Iterable, Optional, Set, Tuple +from typing import Any, Dict, Optional, Set, Tuple import numpy as np @@ -47,7 +61,6 @@ class Model(metaclass=ABCMeta): Note: these fields should match a registered specification (see `declearn.data_info` submodule) """ - return NotImplemented @abstractmethod def initialize( @@ -72,14 +85,12 @@ class Model(metaclass=ABCMeta): See the `aggregate_data_info` method to derive `data_info` from client-wise dict. """ - return None @abstractmethod def get_config( self, ) -> Dict[str, Any]: """Return the model's parameters as a JSON-serializable dict.""" - return NotImplemented @classmethod @abstractmethod @@ -88,14 +99,12 @@ class Model(metaclass=ABCMeta): config: Dict[str, Any], ) -> "Model": """Instantiate a model from a configuration dict.""" - return NotImplemented @abstractmethod def get_weights( self, ) -> Vector: """Return the model's trainable weights.""" - return NotImplemented @abstractmethod def set_weights( @@ -103,7 +112,6 @@ class Model(metaclass=ABCMeta): weights: Vector, ) -> None: """Assign values to the model's trainable weights.""" - return None @abstractmethod def compute_batch_gradients( @@ -134,7 +142,6 @@ class Model(metaclass=ABCMeta): Batch-averaged gradients, wrapped into a Vector (using a suited Vector subclass depending on the Model class). """ - return NotImplemented @abstractmethod def apply_updates( @@ -142,7 +149,6 @@ class Model(metaclass=ABCMeta): updates: Vector, ) -> None: """Apply updates to the model's weights.""" - return None @abstractmethod def compute_batch_predictions( @@ -213,37 +219,3 @@ class Model(metaclass=ABCMeta): s_loss: np.ndarray Sample-wise loss values, as a 1-d numpy array. """ - - def compute_loss( - self, - dataset: Iterable[Batch], - ) -> float: - """Compute the average loss of the model on a given dataset. - - Parameters - ---------- - dataset: iterable of batches - Iterable yielding batch structures that are to be unpacked - into (input_features, target_labels, [sample_weights]). - If set, sample weights will affect the loss averaging. - - Returns - ------- - loss: float - Average value of the model's loss over samples. - """ - warning = DeprecationWarning( - "The `Model.compute_loss` method is deprecated as of v2.0b4. " - "It will be removed in version 2.0, in favor of the Metric API." - ) - warnings.warn(warning) - total = 0.0 - n_btc = 0.0 - for batch in dataset: - y_true, y_pred, s_wght = self.compute_batch_predictions(batch) - loss = self.loss_function(y_true, y_pred) - if s_wght is not None: - loss *= s_wght - total += loss.sum() - n_btc += len(loss) if s_wght is None else s_wght.sum() - return total / n_btc diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py index 398775b05a15ac60b25589cd499dc5874104c98f..f03eb5fa0a8e83a7e862500b8c1eae0c2c3ec675 100644 --- a/declearn/model/api/_vector.py +++ b/declearn/model/api/_vector.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Vector abstraction API.""" import operator @@ -91,7 +106,15 @@ class Vector(metaclass=ABCMeta): self, coefs: Dict[str, Any], ) -> None: - """Instantiate the Vector to wrap a collection of data arrays.""" + """Instantiate the Vector to wrap a collection of data arrays. + + Parameters + ---------- + coefs: dict[str, any] + Dict grouping a named collection of data arrays. + The supported types of that dict's values depends + on the concrete `Vector` subclass being used. + """ self.coefs = coefs @staticmethod @@ -106,6 +129,18 @@ class Vector(metaclass=ABCMeta): implemented Vector subclasses can be made buildable through this staticmethod, which relies on input coefficients' type analysis to infer the Vector type to instantiate and return. + + Parameters + ---------- + coefs: dict[str, any] + Dict grouping a named collection of data arrays, that + all belong to the same framework. + + Returns + ------- + vector: Vector + Vector instance, the concrete class of which depends + on that of the values of the `coefs` dict. """ # Type-check the inputs and look up the Vector subclass to use. if not (isinstance(coefs, dict) and coefs): @@ -141,7 +176,14 @@ class Vector(metaclass=ABCMeta): def shapes( self, ) -> Dict[str, Tuple[int, ...]]: - """Return a dict storing the shape of each coefficient.""" + """Return a dict storing the shape of each coefficient. + + Returns + ------- + shapes: dict[str, tuple(int, ...)] + Dict containing the shape of each of the wrapped data array, + indexed by the coefficient's name. + """ try: return {key: coef.shape for key, coef in self.coefs.items()} except AttributeError as exc: @@ -153,7 +195,16 @@ class Vector(metaclass=ABCMeta): def dtypes( self, ) -> Dict[str, str]: - """Return a dict storing the dtype of each coefficient.""" + """Return a dict storing the dtype of each coefficient. + + Returns + ------- + dtypes: dict[str, tuple(int, ...)] + Dict containing the dtype of each of the wrapped data array, + indexed by the coefficient's name. The dtypes are parsed as + a string, the values of which may vary depending on the + concrete framework of the Vector. + """ try: return {key: str(coef.dtype) for key, coef in self.coefs.items()} except AttributeError as exc: @@ -165,7 +216,22 @@ class Vector(metaclass=ABCMeta): def pack( self, ) -> Dict[str, Any]: - """Return a JSON-serializable dict representation of this Vector.""" + """Return a JSON-serializable dict representation of this Vector. + + This method must return a dict that can be serialized to and from + JSON using the JSON-extending declearn hooks (see `json_pack` and + `json_unpack` functions from the `declearn.utils` module). + + The counterpart `unpack` method may be used to re-create a Vector + from its "packed" dict representation. + + Returns + ------- + packed: dict[str, any] + Dict with str keys, that may be serialized to and from JSON + using the `declearn.utils.json_pack` and `json_unpack` util + functions. + """ return self.coefs @classmethod @@ -173,13 +239,43 @@ class Vector(metaclass=ABCMeta): cls, data: Dict[str, Any], ) -> "Vector": - """Instantiate a Vector from its "packed" dict representation.""" + """Instantiate a Vector from its "packed" dict representation. + + This method is the counterpart to the `pack` one. + + Parameters + ---------- + data: dict[str, any] + Dict produced by the `pack` method of an instance of this class. + + Returns + ------- + vector: Self + Instance of this Vector subclass, (re-)created from the inputs. + """ return cls(data) def apply_func( - self, func: Callable[..., Any], *args: Any, **kwargs: Any + self, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, ) -> "Vector": - """Apply a given function to the wrapped coefficients.""" + """Apply a given function to the wrapped coefficients. + + Parameters + ---------- + func: function(<T>, *args, **kwargs) -> <T> + Function to be applied to each and every coefficient (data + array) wrapped by this Vector, that must return a similar + array (same type, shape and dtype). + *args and **kwargs to `func` may also be passed. + + Returns + ------- + vector: Self + Vector similar to the present one, wrapping the resulting data. + """ coefs = { key: func(coef, *args, **kwargs) for key, coef in self.coefs.items() @@ -191,7 +287,21 @@ class Vector(metaclass=ABCMeta): other: Any, func: Callable[[Any, Any], Any], ) -> "Vector": - """Apply an operation to combine this vector with another.""" + """Apply an operation to combine this vector with another. + + Parameters + ---------- + other: Vector + Vector with the same names, shapes and dtypes as this one. + func: function(<T>, <T>) -> <T> + Function to be applied to combine the data arrays stored + in this vector and the `other` one. + + Returns + ------- + vector: Self + Vector similar to the present one, wrapping the resulting data. + """ # Case when operating on two Vector objects. if isinstance(other, tuple(self.compatible_vector_types)): if self.coefs.keys() != other.coefs.keys(): @@ -277,14 +387,19 @@ class Vector(metaclass=ABCMeta): @abstractmethod def __eq__(self, other: Any) -> bool: - raise NotImplementedError + """Equality operator for Vector classes. + + Two Vectors should be deemed equal if they have the same + specs (same keys, shapes and dtypes) and the same values. + + Otherwise, this magic method should return False. + """ @abstractmethod def sign( self, ) -> "Vector": """Return a Vector storing the sign of each coefficient.""" - raise NotImplementedError @abstractmethod def minimum( @@ -292,7 +407,6 @@ class Vector(metaclass=ABCMeta): other: Union["Vector", float, ArrayLike], ) -> "Vector": """Compute coef.-wise, element-wise minimum wrt to another Vector.""" - raise NotImplementedError @abstractmethod def maximum( @@ -300,14 +414,12 @@ class Vector(metaclass=ABCMeta): other: Union["Vector", float, ArrayLike], ) -> "Vector": """Compute coef.-wise, element-wise maximum wrt to another Vector.""" - raise NotImplementedError @abstractmethod def sum( self, ) -> "Vector": """Compute coefficient-wise sum of elements.""" - raise NotImplementedError def register_vector_type( @@ -345,6 +457,7 @@ def register_vector_type( as a class decorator. """ v_types = (v_type, *types) + # Set up a registration function. def register(cls: Type[Vector]) -> Type[Vector]: nonlocal name, v_types diff --git a/declearn/model/sklearn/__init__.py b/declearn/model/sklearn/__init__.py index 06d3265f4e0cce74946836eecaa8a91e1a0104f5..15702ec3696479d790741d856fdc19d743748f7b 100644 --- a/declearn/model/sklearn/__init__.py +++ b/declearn/model/sklearn/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Scikit-Learn models interfacing tools. Due to the variety of model classes provided by scikit-learn diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 190e2355b37b2654ce52fe53b4ebb218a5b5c304..75d1d42b5954ee56e396f76dab2d3f1ef64a1f08 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """NumpyVector data arrays container.""" from typing import Any, Callable, Dict, Optional, Union @@ -10,7 +25,6 @@ from typing_extensions import Self # future: import from typing (Py>=3.11) from declearn.model.api._vector import Vector, register_vector_type - __all__ = [ "NumpyVector", ] @@ -54,10 +68,14 @@ class NumpyVector(Vector): def __eq__(self, other: Any) -> bool: valid = isinstance(other, NumpyVector) - valid = valid and (self.coefs.keys() == other.coefs.keys()) - return valid and all( - np.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs - ) + if valid: + valid = self.coefs.keys() == other.coefs.keys() + if valid: + valid = all( + np.array_equal(self.coefs[k], other.coefs[k]) + for k in self.coefs + ) + return valid def sign( self, @@ -86,7 +104,7 @@ class NumpyVector(Vector): keepdims: bool = False, ) -> Self: # type: ignore coefs = { - key: np.sum(val, axis=axis, keepdims=keepdims) + key: np.array(np.sum(val, axis=axis, keepdims=keepdims)) for key, val in self.coefs.items() } return self.__class__(coefs) diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 93dccbe0edd9c479ef7ed432f44144e8b805b49f..d7b218952f6c5bf6f8afcd3b09cdd474f5a0e239 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -1,14 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model subclass to wrap scikit-learn SGD classifier and regressor models.""" import typing -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union import numpy as np from numpy.typing import ArrayLike from sklearn.linear_model import SGDClassifier, SGDRegressor # type: ignore -from typing_extensions import Literal # future: import from typing (Py>=3.8) from declearn.data_info import aggregate_data_info from declearn.model.api import Model diff --git a/declearn/model/tensorflow/__init__.py b/declearn/model/tensorflow/__init__.py index ce73d013a66d00c95eefec0bb65871f1f56283e2..ff94c495b4647614cc5f9e9e44b0a9d636f3c3d3 100644 --- a/declearn/model/tensorflow/__init__.py +++ b/declearn/model/tensorflow/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tensorflow models interfacing tools. This submodule provides with a generic interface to wrap up diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 5b8a45ed59e87d69f35cc2b488949a5577657432..e9a24edb58f296a853a30523c4e6bb155eb87537 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model subclass to wrap TensorFlow models.""" from copy import deepcopy diff --git a/declearn/model/tensorflow/_utils.py b/declearn/model/tensorflow/_utils.py index 69bb6e9669c20c419d64439df98348185600c013..5230b3307a2368a6d633b95876839bfb2fa1a451 100644 --- a/declearn/model/tensorflow/_utils.py +++ b/declearn/model/tensorflow/_utils.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Backend utils for the declearn.model.tensorflow module.""" import inspect diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index 163f779ac956839b1f45d3cca4cfb5aa042a480a..e27cc5720cf62a439fc0722dc9a0b832a483eb78 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """TensorflowVector data arrays container.""" from typing import Any, Callable, Dict, Optional, Set, Type, Union @@ -10,11 +25,12 @@ import tensorflow as tf # type: ignore from tensorflow.python.framework.ops import EagerTensor # type: ignore # pylint: enable=no-name-in-module from typing_extensions import Self # future: import from typing (Py>=3.11) -# fmt: on from declearn.model.api import Vector, register_vector_type from declearn.model.sklearn import NumpyVector +# fmt: on + @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) class TensorflowVector(Vector): @@ -116,11 +132,14 @@ class TensorflowVector(Vector): other: Any, ) -> bool: valid = isinstance(other, TensorflowVector) - valid = valid & (self.coefs.keys() == other.coefs.keys()) - return valid and all( - self._tensor_equal(self.coefs[key], other.coefs[key]) - for key in self.coefs - ) + if valid: + valid = self.coefs.keys() == other.coefs.keys() + if valid: + valid = all( + self._tensor_equal(self.coefs[key], other.coefs[key]) + for key in self.coefs + ) + return valid @staticmethod def _tensor_equal( diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py index 122ca3e23e99e1a7e2f623989a4124c6c6a292e6..352d3a398f6958488404c2ecd5aecc4a0f29a43b 100644 --- a/declearn/model/torch/__init__.py +++ b/declearn/model/torch/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tensorflow models interfacing tools. This submodule provides with a generic interface to wrap up diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 87ada068297792755f7c3dc0a2452a22b14f5f90..1e9290cd60f48e26d2239152c535d9303aa03ccb 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Model subclass to wrap PyTorch models.""" import io diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index 73b27738c7c54c85473824c7a1f6b91fc0ef0a17..25470995dfe7a4cef79588ca7742f5c130bc4687 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """TorchVector data arrays container.""" from typing import Any, Callable, Dict, Optional, Set, Tuple, Type @@ -96,11 +111,14 @@ class TorchVector(Vector): other: Any, ) -> bool: valid = isinstance(other, TorchVector) - valid = valid and (self.coefs.keys() == other.coefs.keys()) - return valid and all( - np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy()) - for k in self.coefs - ) + if valid: + valid = self.coefs.keys() == other.coefs.keys() + if valid: + valid = all( + np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy()) + for k in self.coefs + ) + return valid def sign(self) -> Self: # type: ignore # false-positive; pylint: disable=no-member @@ -123,7 +141,7 @@ class TorchVector(Vector): ) -> Self: # type: ignore # false-positive; pylint: disable=no-member if isinstance(other, Vector): - return self._apply_operation(other, torch.minimum) + return self._apply_operation(other, torch.maximum) if isinstance(other, float): other = torch.Tensor([other]) return self.apply_func(torch.maximum, other) diff --git a/declearn/optimizer/__init__.py b/declearn/optimizer/__init__.py index 9858d439681f1c81f91a9b88a8dc78a4cdcae17d..ded1eb0e90dc67198f07b50f02e6a5c9ea1cdb5b 100644 --- a/declearn/optimizer/__init__.py +++ b/declearn/optimizer/__init__.py @@ -1,7 +1,33 @@ # coding: utf-8 -"""Framework-agnostic optimizer tools, both generic or FL-specific.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from . import modules -from . import regularizers +"""Framework-agnostic optimizer tools, both generic or FL-specific. In more +details, we here define an `Optimizer` class that wraps together a set of +modules, used to implement various optimization and regularization techniques. + +Main class: +* Optimizer: Base class to define gradient-descent-based optimizers. + +This module also implements the following submodules, used by the former: +* modules: gradients-alteration algorithms, implemented as plug-in modules. +* regularizers: loss-regularization algorithms, implemented as plug-in modules. + + """ + + +from . import modules, regularizers from ._base import Optimizer diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py index 32f5a71bdba0c58ce22ead2bf39d5e6b293bb69e..9c464063e34746d2d7017ec447616639e41543ab 100644 --- a/declearn/optimizer/_base.py +++ b/declearn/optimizer/_base.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base class to define gradient-descent-based optimizers.""" from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union diff --git a/declearn/optimizer/modules/__init__.py b/declearn/optimizer/modules/__init__.py index 89adf668b458f95fd6365396775d088394ced1eb..d24cf4315210da8a25b1a228d812163e1c4e3471 100644 --- a/declearn/optimizer/modules/__init__.py +++ b/declearn/optimizer/modules/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Optimizer gradients-alteration algorithms, implemented as plug-in modules. Base class implemented here: diff --git a/declearn/optimizer/modules/_adaptive.py b/declearn/optimizer/modules/_adaptive.py index 38528f730f9f8e3c4c147649a77af2c4f9277870..54ff5e794e53eb98fc330b7c86174b030b003f2b 100644 --- a/declearn/optimizer/modules/_adaptive.py +++ b/declearn/optimizer/modules/_adaptive.py @@ -1,14 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Adaptive algorithms for optimizers, implemented as plug-in modules.""" -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule from declearn.optimizer.modules._momentum import EWMAModule, YogiMomentumModule - __all__ = [ "AdaGradModule", "AdamModule", @@ -38,7 +52,7 @@ class AdaGradModule(OptiModule): https://jmlr.org/papers/v12/duchi11a.html """ - name = "adagrad" + name: ClassVar[str] = "adagrad" def __init__( self, @@ -102,7 +116,7 @@ class RMSPropModule(OptiModule): Average of its Recent Magnitude. """ - name = "rmsprop" + name: ClassVar[str] = "rmsprop" def __init__( self, @@ -184,7 +198,7 @@ class AdamModule(OptiModule): https://arxiv.org/abs/1904.09237 """ - name = "adam" + name: ClassVar[str] = "adam" def __init__( self, @@ -304,7 +318,7 @@ class YogiModule(AdamModule): https://arxiv.org/abs/1904.09237 """ - name = "yogi" + name: ClassVar[str] = "yogi" def __init__( self, diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 75444d6b1a78c142eb4ac52c5ab20879a93dad4a..f3a87e7815bec2aec4f86a13dc34d94ba2a97fa8 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -1,9 +1,24 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base API for plug-in gradients-alteration algorithms.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional from declearn.model.api import Vector from declearn.utils import ( @@ -76,15 +91,16 @@ class OptiModule(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented - - aux_name: Optional[str] = None + name: ClassVar[str] = NotImplemented + aux_name: ClassVar[Optional[str]] = None def __init_subclass__( cls, register: bool = True, + **kwargs: Any, ) -> None: """Automatically type-register OptiModule subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, cls.name, group="OptiModule") @@ -239,7 +255,7 @@ class OptiModule(metaclass=ABCMeta): Name based on which the module can be retrieved. Available as a class attribute. config: dict[str, any] - Configuration dict of the regularizer, that is to be + Configuration dict of the module, that is to be passed to its `from_config` class constructor. """ cls = access_registered(name, group="OptiModule") diff --git a/declearn/optimizer/modules/_clipping.py b/declearn/optimizer/modules/_clipping.py index c30691bb4511151f44d9c7022bd2ae4c75822fed..24756b2bf632fe76b82fcbe0b7b708a83fc92973 100644 --- a/declearn/optimizer/modules/_clipping.py +++ b/declearn/optimizer/modules/_clipping.py @@ -1,14 +1,27 @@ # coding: utf-8 -"""Batch-averaged gradients clipping plug-in modules.""" +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from typing import Any, Dict +"""Batch-averaged gradients clipping plug-in modules.""" +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = ["L2Clipping"] @@ -32,7 +45,7 @@ class L2Clipping(OptiModule): to bound the sensitivity associated to that action. """ - name = "l2-clipping" + name: ClassVar[str] = "l2-clipping" def __init__( self, diff --git a/declearn/optimizer/modules/_momentum.py b/declearn/optimizer/modules/_momentum.py index 1b97bf2736e2dc2709282fd9214cc097b8aa1b6f..0923ceb0baed5b5b4865e31d33e7f0d399946c73 100644 --- a/declearn/optimizer/modules/_momentum.py +++ b/declearn/optimizer/modules/_momentum.py @@ -1,8 +1,23 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base API and common examples of plug-in gradients-alteration algorithms.""" -from typing import Any, Dict, Union +from typing import Any, ClassVar, Dict, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule @@ -48,7 +63,7 @@ class MomentumModule(OptiModule): https://proceedings.mlr.press/v28/sutskever13.pdf """ - name = "momentum" + name: ClassVar[str] = "momentum" def __init__( self, @@ -114,7 +129,7 @@ class EWMAModule(OptiModule): decaying moving-average of past gradients. """ - name = "ewma" + name: ClassVar[str] = "ewma" def __init__( self, @@ -184,7 +199,7 @@ class YogiMomentumModule(EWMAModule): Adaptive Methods for Nonconvex Optimization. """ - name = "yogi-momentum" + name: ClassVar[str] = "yogi-momentum" def run( self, diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py index cf91f076ceffe2c87c5346f725018c81a04f9a73..04785191068e1ad3defda614896ff26d0d70dfb4 100644 --- a/declearn/optimizer/modules/_noise.py +++ b/declearn/optimizer/modules/_noise.py @@ -1,10 +1,25 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Noise-addition modules for DP using cryptographically-strong RNG.""" from abc import ABCMeta, abstractmethod from random import SystemRandom -from typing import Any, Dict, Optional, Tuple +from typing import Any, ClassVar, Dict, Optional, Tuple import numpy as np import scipy.stats # type: ignore @@ -26,7 +41,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "abstract-noise" + name: ClassVar[str] = "abstract-noise" def __init__( self, @@ -88,7 +103,6 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): dtype: str, ) -> np.ndarray: """Sample a noise tensor from a module-specific distribution.""" - return NotImplemented class GaussianNoiseModule(NoiseModule): @@ -98,7 +112,7 @@ class GaussianNoiseModule(NoiseModule): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "gaussian-noise" + name: ClassVar[str] = "gaussian-noise" def __init__( self, diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py index b36ce060099aa008b9879e6f696d2494a59e6dfe..46cae6b4c6affccdd690b02da81c3d982f6445c6 100644 --- a/declearn/optimizer/modules/_scaffold.py +++ b/declearn/optimizer/modules/_scaffold.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """SCAFFOLD algorithm for FL, implemented as a pair of plug-in modules. The pair of `OptiModule` classes implemented here serve to implement @@ -16,13 +31,11 @@ References: https://arxiv.org/abs/1910.06378 """ -from typing import Any, Dict, List, Optional, Union - +from typing import Any, ClassVar, Dict, List, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = [ "ScaffoldClientModule", "ScaffoldServerModule", @@ -80,8 +93,8 @@ class ScaffoldClientModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-client" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-client" + aux_name: ClassVar[str] = "scaffold" def __init__( self, @@ -214,8 +227,8 @@ class ScaffoldServerModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-server" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-server" + aux_name: ClassVar[str] = "scaffold" def __init__( self, diff --git a/declearn/optimizer/regularizers/__init__.py b/declearn/optimizer/regularizers/__init__.py index 951abfae22327d6e6ffbedf659da0f4e02c4a294..87c797080dec30547842e769415092ff9c95f10b 100644 --- a/declearn/optimizer/regularizers/__init__.py +++ b/declearn/optimizer/regularizers/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Optimizer loss-regularization algorithms, implemented as plug-in modules. Base class implemented here: diff --git a/declearn/optimizer/regularizers/_api.py b/declearn/optimizer/regularizers/_api.py index a199dab091c88ba122d0b92ff3dbc9e75388784e..7fb771a4e8ed14c3b61d45587112f80aba794eb8 100644 --- a/declearn/optimizer/regularizers/_api.py +++ b/declearn/optimizer/regularizers/_api.py @@ -1,9 +1,24 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base API for loss regularization optimizer plug-ins.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import ( @@ -12,7 +27,6 @@ from declearn.utils import ( register_type, ) - __all__ = [ "Regularizer", ] @@ -68,13 +82,15 @@ class Regularizer(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, register: bool = True, + **kwargs: Any, ) -> None: """Automatically type-register Regularizer subclasses.""" + super().__init_subclass__(**kwargs) if register: register_type(cls, cls.name, group="Regularizer") @@ -115,7 +131,6 @@ class Regularizer(metaclass=ABCMeta): fully compatible with the input one - only the values of the wrapped coefficients may have changed. """ - return NotImplemented def on_round_start( self, diff --git a/declearn/optimizer/regularizers/_base.py b/declearn/optimizer/regularizers/_base.py index 0bcac9cf2b6cda276d27b3dfb5be91f9f1d1ea36..f96e6f3681f217c5733a9f87c75b8eff74b01d33 100644 --- a/declearn/optimizer/regularizers/_base.py +++ b/declearn/optimizer/regularizers/_base.py @@ -1,13 +1,27 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Common plug-in loss-regularization plug-ins.""" -from typing import Optional +from typing import ClassVar, Optional from declearn.model.api import Vector from declearn.optimizer.regularizers._api import Regularizer - __all__ = [ "FedProxRegularizer", "LassoRegularizer", @@ -40,7 +54,7 @@ class FedProxRegularizer(Regularizer): https://arxiv.org/abs/1812.06127 """ - name = "fedprox" + name: ClassVar[str] = "fedprox" def __init__( self, @@ -76,7 +90,7 @@ class LassoRegularizer(Regularizer): grads += alpha * sign(weights) """ - name = "lasso" + name: ClassVar[str] = "lasso" def run( self, @@ -97,7 +111,7 @@ class RidgeRegularizer(Regularizer): grads += alpha * 2 * weights """ - name = "ridge" + name: ClassVar[str] = "ridge" def run( self, diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index 177db06ee824b6565f87bd1fa79d9c26980ae2da..507d8ebffeea17c1601e4bfca51eaa2d2298adf3 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Collection of utils for running tests and examples around declearn.""" from ._assertions import assert_json_serializable_dict diff --git a/declearn/test_utils/_assertions.py b/declearn/test_utils/_assertions.py index 1c6e8b05410e6ba00ad8536eb88dc0f099708597..456720c9214b8c520a07f6cf7f08df3926e373a3 100644 --- a/declearn/test_utils/_assertions.py +++ b/declearn/test_utils/_assertions.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Custom "assert" functions commonly used in declearn tests.""" import json diff --git a/declearn/test_utils/_gen_ssl.py b/declearn/test_utils/_gen_ssl.py index 6d49b7cb9b8a5df6376d4f51be1a277898382bc5..5c0cd8f395348a91289726dfd328b87e4a83dce7 100644 --- a/declearn/test_utils/_gen_ssl.py +++ b/declearn/test_utils/_gen_ssl.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared fixtures for declearn.communication module testing.""" import os diff --git a/declearn/test_utils/_multiprocess.py b/declearn/test_utils/_multiprocess.py index faa266c9a75ccfcf9f843da54bcf0145ecf1879b..05eeb251dd471b2fc3079f9b97beacbea03d0ba1 100644 --- a/declearn/test_utils/_multiprocess.py +++ b/declearn/test_utils/_multiprocess.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Utils to run concurrent routines parallelly using multiprocessing.""" import multiprocessing as mp diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py index 3fc24f479bdee682be3d1de1aa5ab8220326d781..d789cec4933795a0364b3083164ff449b9f79c57 100644 --- a/declearn/test_utils/_vectors.py +++ b/declearn/test_utils/_vectors.py @@ -1,20 +1,33 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared objects for testing purposes.""" import importlib import typing -from typing import List, Optional, Type +from typing import List, Literal, Optional, Type import numpy as np import pkg_resources from numpy.typing import ArrayLike -from typing_extensions import Literal # future: import from typing (Py>=3.8) from declearn.model.api import Vector from declearn.model.sklearn import NumpyVector - __all__ = [ "FrameworkType", "GradientsTestCase", @@ -102,14 +115,28 @@ class GradientsTestCase: ) @property - def mock_allzero_gradient(self) -> Vector: + def mock_ones(self) -> Vector: """Instantiate a Vector with random-valued mock gradients. Note: the RNG used to generate gradients has a fixed seed, to that gradients have the same values whatever the tensor framework used is. """ - shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)] + shapes = [(5, 5), (4,), (1,)] + values = [np.ones(shape) for shape in shapes] + return self.vector_cls( + {str(idx): self.convert(value) for idx, value in enumerate(values)} + ) + + @property + def mock_zeros(self) -> Vector: + """Instantiate a Vector with random-valued mock gradients. + + Note: the RNG used to generate gradients has a fixed seed, + to that gradients have the same values whatever the + tensor framework used is. + """ + shapes = [(5, 5), (4,), (1,)] values = [np.zeros(shape) for shape in shapes] return self.vector_cls( {str(idx): self.convert(value) for idx, value in enumerate(values)} diff --git a/declearn/typing.py b/declearn/typing.py index 5c0ba98fc9e2d1c0ef408fa3f7cc6522d0f3dca1..2f18323081f238caa204a2fbf18ed70845187299 100644 --- a/declearn/typing.py +++ b/declearn/typing.py @@ -1,15 +1,27 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Type hinting utils, defined and exposed for code readability purposes.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union from numpy.typing import ArrayLike -from typing_extensions import ( - Protocol, # future: import from typing (Py>=3.8) - Self, # future: import from typing (Py>=3.11) -) +from typing_extensions import Self # future: import from typing (Py>=3.11) __all__ = [ diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index 7d10cb2075452dc87c7ead55be378322b3bd03d6..00bf9fc14932fef61b6eaa4ab9b8aefeb404374c 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared utils used across declearn. The key functionalities implemented here are: diff --git a/declearn/utils/_dataclass.py b/declearn/utils/_dataclass.py index 89c2214c49789fe6bc989f09127f5b03f4cc4b77..da35292529e788fa5186dca7fddf95671dedc78c 100644 --- a/declearn/utils/_dataclass.py +++ b/declearn/utils/_dataclass.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Dataclass-generation tools. These tools are meant to reduce redundant code and ease maintanability @@ -69,7 +84,7 @@ def dataclass_from_func( # Parse the function's signature into dataclass Field instances. signature = inspect.signature(func) parameters = list(signature.parameters.values()) - fields = parameters_to_fields(parameters) + fields = _parameters_to_fields(parameters) # Make a dataclass out of the former fields. if not name: name = "".join(w.capitalize() for w in func.__name__.split("_")) @@ -134,7 +149,7 @@ def dataclass_from_init( """ # Parse the class's __init__ signature into dataclass Field instances. parameters = list(inspect.signature(cls.__init__).parameters.values())[1:] - fields = parameters_to_fields(parameters) + fields = _parameters_to_fields(parameters) # Make a dataclass out of the former fields. name = name or f"{cls.__name__}Config" dcls = dataclasses.make_dataclass(name, fields) # type: Type @@ -151,6 +166,7 @@ def dataclass_from_init( args_field = param.name if param.kind is param.VAR_KEYWORD: kwargs_field = param.name + # Add a method to instantiate from the dataclass. def instantiate(self) -> cls: # type: ignore """Instantiate from the wrapped init parameters.""" @@ -168,7 +184,7 @@ def dataclass_from_init( return dcls # type: ignore -def parameters_to_fields( +def _parameters_to_fields( params: List[inspect.Parameter], ) -> List[Tuple[str, Type, dataclasses.Field]]: """Parse function or method parameters into dataclass fields.""" diff --git a/declearn/utils/_json.py b/declearn/utils/_json.py index a8d2a29b0e626064356bb616ae73179c5d34f0f6..2d6ac181e9f996ec59ea1f003fa01ba64002db12 100644 --- a/declearn/utils/_json.py +++ b/declearn/utils/_json.py @@ -1,13 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tools to add support for non-standard types' JSON-(de)serialization.""" import dataclasses import json import warnings -from typing import Any, Callable, Dict, Optional, Type - -from typing_extensions import TypedDict # future: import from typing (Py>=3.8) +from typing import Any, Callable, Dict, Optional, Type, TypedDict __all__ = [ diff --git a/declearn/utils/_logging.py b/declearn/utils/_logging.py index d8bb026b4309ec9b6f06972833917d4b0ce8d981..b1c926918541db206dc71f3d7b1a935177daf596 100644 --- a/declearn/utils/_logging.py +++ b/declearn/utils/_logging.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Logging tools for declearn internal use.""" import logging diff --git a/declearn/utils/_numpy.py b/declearn/utils/_numpy.py index 243cc2e34e7a1510ed4498190f41821bb41f3811..5c791d66940106fe99fe8d354762397c300bd8ad 100644 --- a/declearn/utils/_numpy.py +++ b/declearn/utils/_numpy.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Numpy-related declearn utils.""" from typing import List, Tuple diff --git a/declearn/utils/_register.py b/declearn/utils/_register.py index ebbefcb950362595deb045bfc666ad793be785f1..7453f4747171117a8e69f9ef6c68eeef2b9458a5 100644 --- a/declearn/utils/_register.py +++ b/declearn/utils/_register.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Generic types-registration system backing some (de)serialization utils.""" import functools diff --git a/declearn/utils/_serialize.py b/declearn/utils/_serialize.py index e73987c6ebdd7fbe9a227b179ed8c634de9c6767..ab2d296220f31463d9a6223b672c9b50de0e72d5 100644 --- a/declearn/utils/_serialize.py +++ b/declearn/utils/_serialize.py @@ -1,14 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Generic tools to (de-)serialize custom declearn objects to and from JSON.""" import dataclasses -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypedDict, Union -from typing_extensions import ( - Self, # future: import from typing (Py>=3.11) - TypedDict, # future: import from typing (Py>=3.8) -) +from typing_extensions import Self # future: import from typing (Py>=3.11) from declearn.utils._register import ( access_registered, diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py index 323adab3651de4901112b82dc56662aa1bb31a5b..5f5f7e321729ec18c637acd631a8c4673507c8ca 100644 --- a/declearn/utils/_toml_config.py +++ b/declearn/utils/_toml_config.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Base class to define TOML-parsable configuration containers.""" import dataclasses @@ -10,11 +25,11 @@ try: import tomllib # type: ignore except ModuleNotFoundError: import tomli as tomllib # type: ignore # required re-definition + from typing import Any, Dict, Optional, Type, TypeVar, Union from typing_extensions import Self # future: import from typing (py >=3.11) - __all__ = [ "TomlConfig", ] @@ -23,7 +38,7 @@ __all__ = [ T = TypeVar("T") -def isinstance_generic(inputs: Any, typevar: Type) -> bool: +def _isinstance_generic(inputs: Any, typevar: Type) -> bool: """Override of `isinstance` built-in that supports some typing generics. Note @@ -63,25 +78,25 @@ def isinstance_generic(inputs: Any, typevar: Type) -> bool: args = typing.get_args(typevar) # Case of a Union generic. if origin is typing.Union: - return any(isinstance_generic(inputs, typevar) for typevar in args) + return any(_isinstance_generic(inputs, typevar) for typevar in args) # Case of a Dict[..., ...] generic. if origin is dict: return ( isinstance(inputs, dict) - and all(isinstance_generic(k, args[0]) for k in inputs) - and all(isinstance_generic(v, args[1]) for v in inputs.values()) + and all(_isinstance_generic(k, args[0]) for k in inputs) + and all(_isinstance_generic(v, args[1]) for v in inputs.values()) ) # Case of a List[...] generic. if origin is list: return isinstance(inputs, list) and all( - isinstance_generic(e, args[0]) for e in inputs + _isinstance_generic(e, args[0]) for e in inputs ) # Case of a Tuple[...] generic. if origin is tuple: return ( isinstance(inputs, tuple) and len(inputs) == len(args) - and all(isinstance_generic(e, t) for e, t in zip(inputs, args)) + and all(_isinstance_generic(e, t) for e, t in zip(inputs, args)) ) # Unsupported cases. raise TypeError( @@ -90,12 +105,12 @@ def isinstance_generic(inputs: Any, typevar: Type) -> bool: ) -def parse_float(src: str) -> Optional[float]: +def _parse_float(src: str) -> Optional[float]: """Custom float parser that replaces nan values with None.""" return None if src == "nan" else float(src) -def instantiate_field( +def _instantiate_field( field: dataclasses.Field, # future: dataclasses.Field[T] (Py >=3.9) *args: Any, **kwargs: Any, @@ -244,7 +259,7 @@ class TomlConfig: Instantiated object that matches the field's specifications. """ # Case of valid inputs: return them as-is (including valid None). - if isinstance_generic(inputs, field.type): # see function's notes + if _isinstance_generic(inputs, field.type): # see function's notes return inputs # Case of None inputs: return default value if any, else raise. if inputs is None: @@ -262,16 +277,16 @@ class TomlConfig: return field.type.from_toml(inputs) # type: ignore # Otherwise, conduct minimal parsing. with open(inputs, "rb") as file: - config = tomllib.load(file, parse_float=parse_float) + config = tomllib.load(file, parse_float=_parse_float) section = config.get(field.name, config) # subsection or full file return ( - instantiate_field(field, **section) + _instantiate_field(field, **section) if isinstance(section, dict) - else instantiate_field(field, section) + else _instantiate_field(field, section) ) # Case of dict inputs: try instantiating the target type. if isinstance(inputs, dict): - return instantiate_field(field, **inputs) + return _instantiate_field(field, **inputs) # Otherwise, raise a TypeError. raise TypeError(f"Failed to parse inputs for field {field.name}.") @@ -315,7 +330,7 @@ class TomlConfig: # Parse the TOML configuration file. try: with open(path, "rb") as file: - config = tomllib.load(file, parse_float=parse_float) + config = tomllib.load(file, parse_float=_parse_float) except tomllib.TOMLDecodeError as exc: raise RuntimeError( "Failed to parse the TOML configuration file." diff --git a/examples/adding_rmsprop/readme.md b/examples/adding_rmsprop/readme.md index 7d47d6e91f2c32c7aacc7f142a8437f24bbd175d..40b748e5a37475944def19bcdc67642991e45f43 100644 --- a/examples/adding_rmsprop/readme.md +++ b/examples/adding_rmsprop/readme.md @@ -65,7 +65,7 @@ class RMSPropModule(OptiModule): # Identifier, that must be unique across modules for type-registration # purposes. This enables specifying the module in configuration files. - name = "rmsprop" + name:ClassVar[str] = "rmsprop" # Define optimizer parameters, here beta and eps diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py index 7a85018e37e1ce0aab1a75379bdceae9a6b7b913..ff9ae126014afb869772a6d20b0f1290f1b55d30 100644 --- a/examples/heart-uci/client.py +++ b/examples/heart-uci/client.py @@ -1,3 +1,20 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Script to run a federated client on the heart-disease example.""" import argparse diff --git a/examples/heart-uci/data.py b/examples/heart-uci/data.py index cb11b2608e9ae04bf6f4bd819961b801a2be738e..20ed3d6cb63264d99279ac3425bf0da801232324 100644 --- a/examples/heart-uci/data.py +++ b/examples/heart-uci/data.py @@ -1,3 +1,20 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Script to download and pre-process the UCI Heart Disease Dataset.""" import argparse diff --git a/examples/heart-uci/gen_ssl.py b/examples/heart-uci/gen_ssl.py index c233c610a8bdc49b4b81e346f865e61e9e91bf81..94f81e982c85f0813abf342337387922482b0eeb 100644 --- a/examples/heart-uci/gen_ssl.py +++ b/examples/heart-uci/gen_ssl.py @@ -1,3 +1,20 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Script to generate self-signed SSL certificates for the demo.""" import os diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py index 10bd93cf68edd8670ed76b82e8c032cdff834927..aa7f599a9ec018394dbc083e384c7591ce881793 100644 --- a/examples/heart-uci/run.py +++ b/examples/heart-uci/run.py @@ -1,3 +1,20 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Demonstration script using the UCI Heart Disease Dataset.""" import os diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py index 93eea30962a7b70f30098d3e5eb71882270f20db..f63607c6e37e0b0ac088a0e64dbd6524bd1df239 100644 --- a/examples/heart-uci/server.py +++ b/examples/heart-uci/server.py @@ -1,3 +1,20 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Script to run a federated server on the heart-disease example.""" import argparse diff --git a/pyproject.toml b/pyproject.toml index 4296a8e299d37189cd03ef61c1187e4d4ef60d20..08608878e46de4de7fe0b17e54a7b9a4b29b0218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,26 +8,29 @@ requires = [ [project] name = "declearn" -version = "2.0.0.beta3" -description = "Declearn - a python package for decentralized learning." +version = "2.0.0" +description = "Declearn - a python package for private decentralized learning." readme = "README.md" -requires-python = ">=3.7" -license = {text = "MIT"} +requires-python = ">=3.8" +license = {file = "LICENSE"} authors = [ {name = "Paul Andrey", email = "paul.andrey@inria.fr"}, {name = "Nathan Bigaud", email = "nathan.bigaud@inria.fr"}, - {name = "Nathalie Vauquier", email = "nathalie.vauquier@inria.fr"}, ] maintainers = [ {name = "Paul Andrey", email = "paul.andrey@inria.fr"}, - {name = "Nathalie Vauquier", email = "nathalie.vauquier@inria.fr"}, + {name = "Nathan Bigaud", email = "nathan.bigaud@inria.fr"}, ] classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", + "License :: OSI Approved :: Apache Software License", "Operating System :: UNIX", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Mathematics", "Typing :: Typed", @@ -36,59 +39,59 @@ dependencies = [ "cryptography >= 35.0", "grpcio >= 1.45", "pandas >= 1.2", - "scikit-learn >= 1.0", "tomli >= 2.0 ; python_version < '3.11'", "typing_extensions >= 4.0", - "websockets >= 10.1", + "websockets ~= 10.1", ] [project.optional-dependencies] all = [ # all non-tests extra dependencies - "functorch >= 0.1", + "functorch", "grpcio >= 1.45", - "opacus >= 1.1", + "opacus ~= 1.1", "protobuf >= 3.19", - "tensorflow >= 2.5", - "torch >= 1.10", - "websockets >= 10.1", + "tensorflow ~= 2.5", + "torch ~= 1.10", + "websockets ~= 10.1", ] dp = [ - "opacus >= 1.1", + "opacus ~= 1.1", ] grpc = [ "grpcio >= 1.45", "protobuf >= 3.19", ] tensorflow = [ - "tensorflow >= 2.5", + "tensorflow ~= 2.5", ] torch = [ - "functorch >= 0.1", # note: functorch is included with torch>=1.13 - "torch >= 1.10", + "functorch", # note: functorch is included with torch>=1.13 + "torch ~= 1.10", ] websockets = [ - "websockets >= 10.1", + "websockets ~= 10.1", ] tests = [ # test-specific dependencies - "black ~= 22.0", + "black ~= 23.0", "mypy >= 0.930", "pylint >= 2.14", "pytest >= 6.1", "pytest-asyncio", # other extra dependencies (copy of "all") - "functorch >= 0.1", + "functorch", "grpcio >= 1.45", - "opacus >= 1.1", + "opacus ~= 1.1", "protobuf >= 3.19", - "tensorflow >= 2.5", - "torch >= 1.10", - "websockets >= 10.1", + "tensorflow ~= 2.5", + "torch ~= 1.10", + "websockets ~= 10.1", ] [project.urls] -source = "https://gitlab.inria.fr/magnet/declearn/declearn2" +homepage = "https://gitlab.inria.fr/magnet/declearn/declearn2" +repository = "https://gitlab.inria.fr/magnet/declearn/declearn2.git" [tool.black] line-length = 79 diff --git a/test/communication/conftest.py b/test/communication/conftest.py index cb09fd0b83c757771f2ea176a63e37825391e4d3..5c0c4e34a7ff57fd76125176dbd4464724e33729 100644 --- a/test/communication/conftest.py +++ b/test/communication/conftest.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared fixtures for declearn.communication module testing.""" import tempfile diff --git a/test/communication/test_grpc.py b/test/communication/test_grpc.py index 7530189ab3dd61e8ee24cd4a790c3b81b4f585bd..3d5f88562424e2ccd7dcc2e311c4fd2de06e3956 100644 --- a/test/communication/test_grpc.py +++ b/test/communication/test_grpc.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for gRPC network communication tools. The tests implemented here only test that communications work as expected, diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py index 3f04cd1deaba17cadcfee73632bafda88f67c6e5..c1d746f2258b3a921f5430b6408f6f73d2c60a53 100644 --- a/test/communication/test_routines.py +++ b/test/communication/test_routines.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Functional test for declearn.communication classes. The test implemented here spawns a NetworkServer endpoint as well as one @@ -131,6 +146,7 @@ def _build_server_func( "certificate": ssl_cert["server_cert"] if use_ssl else None, "private_key": ssl_cert["server_pkey"] if use_ssl else None, } # type: Dict[str, Any] + # Define a coroutine that spawns and runs a server. async def server_coroutine() -> None: """Spawn a client and run `server_routine` in its context.""" @@ -158,6 +174,7 @@ def _build_client_funcs( server_uri = "localhost:8765" if protocol == "websockets": server_uri = f"ws{'s' * use_ssl}://{server_uri}" + # Define a coroutine that spawns and runs a client. async def client_coroutine( name: str, diff --git a/test/conftest.py b/test/conftest.py index 125cd6f80687f649c669bb679810afcdfe37f5a3..5d51c2c90bb87f3062215fc1a537281c810a3e19 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared pytest configuration code for the test suite.""" import pytest diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py index 4a3e4d8831b88d31d12ff656b46a395782852e00..f35ced380c586100bb85cb2e572027c9f3daea4f 100644 --- a/test/main/test_checkpoint.py +++ b/test/main/test_checkpoint.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for Checkpointer class.""" import json @@ -9,9 +24,9 @@ from typing import Dict, Iterator, List, Union from unittest import mock import numpy as np -import pandas as pd +import pandas as pd # type: ignore import pytest -from sklearn.linear_model import SGDClassifier +from sklearn.linear_model import SGDClassifier # type: ignore from declearn.main.utils import Checkpointer from declearn.model.api import Model diff --git a/test/main/test_train_manager.py b/test/main/test_train_manager.py index e29e08952dc05e7fb18f63777277d6923cb8cef8..8adf85514a36967707050e0cc9d75ed6e09a71ac 100644 --- a/test/main/test_train_manager.py +++ b/test/main/test_train_manager.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.main.utils.TrainingManager`.""" from unittest import mock diff --git a/test/main/test_train_manager_dp.py b/test/main/test_train_manager_dp.py index 9452ce6062d0d005c64293e50288f97f3422eb88..c55b5abdc2fafff5ad9938790e03f8c529a029a0 100644 --- a/test/main/test_train_manager_dp.py +++ b/test/main/test_train_manager_dp.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.main.privacy.DPTrainingManager`.""" import sys diff --git a/test/metrics/metric_testing.py b/test/metrics/metric_testing.py index 264c7b7d62e79ce86cb55230d90a13cd3b3cf898..4141465093b268eb348872971f4cda30166a22d1 100644 --- a/test/metrics/metric_testing.py +++ b/test/metrics/metric_testing.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Template test-suite for declearn Metric subclasses.""" import json diff --git a/test/metrics/test_binary_apr.py b/test/metrics/test_binary_apr.py index 0e3234bde187c586eec98ec01eec763849b16969..0473f550d12dbd8811f56e46a52dcb9cd64ef701 100644 --- a/test/metrics/test_binary_apr.py +++ b/test/metrics/test_binary_apr.py @@ -1,14 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.metrics.BinaryAccuracyPrecisionRecall`.""" import os import sys -from typing import Dict, Union, Tuple +from typing import Dict, Literal, Union, Tuple import numpy as np import pytest -from typing_extensions import Literal # future: import from typing (py>=3.8) from declearn.metrics import BinaryAccuracyPrecisionRecall diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py index 0b84e2caf507daa2e6861bb4e54a3d4a24ab9929..880b843183a3f8def68e45d2d20c8e4e3c0d3bcf 100644 --- a/test/metrics/test_binary_roc.py +++ b/test/metrics/test_binary_roc.py @@ -1,15 +1,29 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.metrics.BinaryRocAUC`.""" import os import sys -from typing import Dict, Union, Tuple +from typing import Dict, Literal, Union, Tuple import numpy as np import pytest import sklearn # type: ignore -from typing_extensions import Literal # future: import from typing (py>=3.8) from declearn.metrics import BinaryRocAUC diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py index e125e8e5ab710e98648fe667399f1758aebabeff..823e89bb97f156177ebd5959c6cd3973f2f08813 100644 --- a/test/metrics/test_mae_mse.py +++ b/test/metrics/test_mae_mse.py @@ -1,14 +1,28 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit and functional tests for the MAE and MSE Metric subclasses.""" import os import sys -from typing import Dict, Union +from typing import Dict, Literal, Union import numpy as np import pytest -from typing_extensions import Literal # future: import from typing (py>=3.8) from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric diff --git a/test/metrics/test_metricset.py b/test/metrics/test_metricset.py index 62cedf7138e7fa615a97991ebda9b139e1591a4f..3e66148fb218f69c16c8bf0228e2d960a2d92700 100644 --- a/test/metrics/test_metricset.py +++ b/test/metrics/test_metricset.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.metrics.MetricSet`.""" from unittest import mock @@ -11,9 +26,9 @@ import pytest from declearn.metrics import MeanAbsoluteError, MeanSquaredError, MetricSet -def get_mock_metricset() -> Tuple[ - MeanAbsoluteError, MeanSquaredError, MetricSet -]: +def get_mock_metricset() -> ( + Tuple[MeanAbsoluteError, MeanSquaredError, MetricSet] +): """Provide with a MetricSet wrapping mock metrics.""" mae = mock.create_autospec(MeanAbsoluteError, instance=True) mae.name = MeanAbsoluteError.name diff --git a/test/metrics/test_multi_apr.py b/test/metrics/test_multi_apr.py index cbb8285dcc4708337b7f5900cc7028ce7bf8bc57..659bf706ad1d44bfd91dd9b1bbcda1bbdfc7453e 100644 --- a/test/metrics/test_multi_apr.py +++ b/test/metrics/test_multi_apr.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.metrics.MulticlassAccuracyPrecisionRecall`.""" import os diff --git a/test/model/model_testing.py b/test/model/model_testing.py index 6267d2d788826bcf376abc0dee570b2770fd0340..4751a18fc3ae9c7ecfd809d541286c147027030d 100644 --- a/test/model/model_testing.py +++ b/test/model/model_testing.py @@ -1,12 +1,26 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared testing code for TensorFlow and Torch models' unit tests.""" import json from typing import Any, List, Protocol, Tuple, Type, Union import numpy as np -import pytest from declearn.model.api import Model, Vector from declearn.typing import Batch @@ -157,15 +171,6 @@ class ModelTestSuite: other = json.loads(gdump, object_hook=json_unpack) assert grads == other - def test_compute_loss( - self, - test_case: ModelTestCase, - ) -> None: - """Test that loss computation abides by its specs.""" - with pytest.warns(DeprecationWarning): - loss = test_case.model.compute_loss(test_case.dataset) - assert isinstance(loss, float) - def test_compute_batch_predictions( self, test_case: ModelTestCase, @@ -190,13 +195,13 @@ class ModelTestSuite: """Test that the exposed loss function abides by its specs.""" model = test_case.model batch = test_case.dataset[0] + # Test that sample-wise loss computation works. y_true, y_pred, s_wght = model.compute_batch_predictions(batch) s_loss = model.loss_function(y_true, y_pred).squeeze() assert isinstance(s_loss, np.ndarray) and s_loss.ndim == 1 assert len(s_loss) == len(s_wght if s_wght is not None else y_true) + # Test that sample weights are properly combinable with that loss. if s_wght is None: s_wght = np.ones_like(s_loss) r_loss = (s_loss * s_wght).sum() / s_wght.sum() - with pytest.warns(DeprecationWarning): - loss = model.compute_loss([batch]) - assert r_loss == loss + r_loss = float(r_loss) # conversion from numpy.float diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py index 67ff9e64dc9a70b87330975bb153f8fe13ca1d50..b214f28ac1c0fb595365dec187ca8ac94d52ed71 100644 --- a/test/model/test_sksgd.py +++ b/test/model/test_sksgd.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for SklearnSGDModel.""" import sys diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index e39146c7edaea885195cd11e4e45d8839c27f97e..aafd5d5d567ff37baba544ab8ff34904a529674d 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -1,18 +1,35 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for TensorflowModel.""" import warnings import sys -from typing import Any, List +from typing import Any, List, Literal import numpy as np import pytest -with warnings.catch_warnings(): # silence tensorflow import-time warnings - warnings.simplefilter("ignore") - import tensorflow as tf # type: ignore -from typing_extensions import Literal # future: import from typing (Py>=3.8) +try: + with warnings.catch_warnings(): # silence tensorflow import-time warnings + warnings.simplefilter("ignore") + import tensorflow as tf # type: ignore +except ModuleNotFoundError: + pytest.skip("TensorFlow is unavailable", allow_module_level=True) from declearn.model.tensorflow import TensorflowModel, TensorflowVector from declearn.typing import Batch diff --git a/test/model/test_torch.py b/test/model/test_torch.py index c3731a526f66bd1c1e4de8a009dc49fa2367f8ef..18b6418a632182262e040ecada563e3cc958ebee 100644 --- a/test/model/test_torch.py +++ b/test/model/test_torch.py @@ -1,14 +1,32 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for TorchModel.""" import sys -from typing import Any, List, Tuple +from typing import Any, List, Literal, Tuple import numpy as np import pytest -import torch -from typing_extensions import Literal # future: import from typing (Py>=3.8) + +try: + import torch +except ModuleNotFoundError: + pytest.skip("PyTorch is unavailable", allow_module_level=True) from declearn.model.torch import TorchModel, TorchVector from declearn.typing import Batch diff --git a/test/model/test_vector.py b/test/model/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..2b46f002329e47a7b098d20e83d7d86fc5c0040c --- /dev/null +++ b/test/model/test_vector.py @@ -0,0 +1,170 @@ +# coding: utf-8 + +"""Unit tests for Vector and its subclasses. + +This test makes use of `declearn.test_utils.list_available_frameworks` +so as to modularly run a standard test suite on the available Vector +subclasses. +""" + +import json + +import numpy as np +import pytest + +from declearn.test_utils import ( + FrameworkType, + GradientsTestCase, + list_available_frameworks, +) +from declearn.utils import json_pack, json_unpack + + +@pytest.fixture(name="framework", params=list_available_frameworks()) +def framework_fixture(request): + """Fixture to provide with the name of a model framework.""" + return request.param + + +class TestVectorAbstractMethods: + """Test abstract methods.""" + + def test_sum(self, framework: FrameworkType) -> None: + """Test coefficient-wise sum.""" + grad = GradientsTestCase(framework) + ones = grad.mock_ones + test_coefs = ones.sum().coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + values = [25.0, 4.0, 1.0] + assert values == test_values + + def test_max(self, framework: FrameworkType) -> None: + """Test coef.-wise, element-wise maximum wrt to another Vector.""" + grad = GradientsTestCase(framework) + ones, zeros = (grad.mock_ones, grad.mock_zeros) + values = [np.ones((5, 5)), np.ones((4,)), np.ones((1,))] + # test Vector + test_coefs = zeros.maximum(ones).coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + assert all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + # test float + test_coefs = zeros.maximum(1.0).coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + assert all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + + def test_min(self, framework: FrameworkType) -> None: + """Test coef.-wise, element-wise minimum wrt to another Vector.""" + grad = GradientsTestCase(framework) + ones, zeros = (grad.mock_ones, grad.mock_zeros) + values = [np.zeros((5, 5)), np.zeros((4,)), np.zeros((1,))] + # test Vector + test_coefs = ones.minimum(zeros).coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + assert all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + # test float + test_coefs = zeros.minimum(1.0).coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + assert all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + + def test_sign(self, framework: FrameworkType) -> None: + """Test coefficient-wise sign check""" + grad = GradientsTestCase(framework) + ones = grad.mock_ones + for vec in ones, -1 * ones: + test_coefs = vec.sign().coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + values = [grad.to_numpy(vec.coefs[el]) for el in vec.coefs] + assert all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + + def test_eq(self, framework: FrameworkType) -> None: + """Test __eq__ operator""" + grad = GradientsTestCase(framework) + ones, ones_bis, zeros = grad.mock_ones, grad.mock_ones, grad.mock_zeros + rand = grad.mock_gradient + assert ones == ones_bis + assert zeros != ones + assert ones != rand + assert 1.0 != ones + + +class TestVector: + """Test non-abstract methods""" + + def test_operator(self, framework: FrameworkType) -> None: + "Test all element-wise operators wiring" + grad = GradientsTestCase(framework) + + def _get_sq_root_two(ones, zeros): + """Returns the comaprison of a hardcoded sequence of operations + with its exptected result""" + values = [ + el * (2 ** (1 / 2)) + for el in [np.ones((5, 5)), np.ones((4,)), np.ones((1,))] + ] + test_grad = (0 + (1.0 * ones + ones * 1.0) * ones / 1.0 - 0) ** ( + (zeros + ones - zeros) / (ones + ones) + ) + test_coefs = test_grad.coefs + test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs] + return all( + (values[i] == test_values[i]).all() for i in range(len(values)) + ) + + ones, zeros = grad.mock_ones, grad.mock_zeros + assert _get_sq_root_two(ones, zeros) + assert _get_sq_root_two(ones, 0) + + def test_pack(self, framework: FrameworkType) -> None: + """Test that `Vector.pack` returns JSON-serializable results.""" + grad = GradientsTestCase(framework) + ones = grad.mock_ones + packed = ones.pack() + # Check that the output is a dict with str keys. + assert isinstance(packed, dict) + assert all(isinstance(key, str) for key in packed) + # Check that the "packed" dict is JSON-serializable. + dump = json.dumps(packed, default=json_pack) + load = json.loads(dump, object_hook=json_unpack) + assert isinstance(load, dict) + assert load.keys() == packed.keys() + assert all(np.all(load[key] == packed[key]) for key in load) + + def test_unpack(self, framework: FrameworkType) -> None: + """Test that `Vector.unpack` counterparts `Vector.pack` adequately.""" + grad = GradientsTestCase(framework) + ones = grad.mock_ones + packed = ones.pack() + test_vec = grad.vector_cls.unpack(packed) + assert test_vec == ones + + def test_repr(self, framework: FrameworkType) -> None: + """Test shape and dtypes together using __repr__""" + grad = GradientsTestCase(framework) + test_value = repr(grad.mock_ones) + value = grad.mock_ones.coefs["0"] + arr_type = f"{type(value).__module__}.{type(value).__name__}" + value = ( + f"{grad.vector_cls.__name__} with 3 coefs:" + f"\n 0: float64 {arr_type} with shape (5, 5)" + f"\n 1: float64 {arr_type} with shape (4,)" + f"\n 2: float64 {arr_type} with shape (1,)" + ) + assert test_value == value + + def test_json_serialization(self, framework: FrameworkType) -> None: + """Test that a Vector instance is JSON-serializable.""" + vector = GradientsTestCase(framework).mock_gradient + dump = json.dumps(vector, default=json_pack) + loaded = json.loads(dump, object_hook=json_unpack) + assert isinstance(loaded, type(vector)) + assert loaded == vector diff --git a/test/optimizer/conftest.py b/test/optimizer/conftest.py index 08b73ed2b2c933ccfe3fa097547b1563adcbd671..b2878c7d0b57814c6fe67d209d1a9734e61f19ca 100644 --- a/test/optimizer/conftest.py +++ b/test/optimizer/conftest.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared pytest fixtures for testing optmizer and plugins.""" diff --git a/test/optimizer/optim_testing.py b/test/optimizer/optim_testing.py index 3f953d6db6a67c899c7ad176c18a946ebf339fe0..8b2469f0fa81aa38cfe2ea0d7e493f2c5b5e18a9 100644 --- a/test/optimizer/optim_testing.py +++ b/test/optimizer/optim_testing.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Shared code to define unit tests for declearn optimizer plug-in classes.""" import warnings diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py index 55f4168f74f7a325552a8cf9d1e8112f47acf1b1..0bbde31ad6746bd94b2cc631c4d2b0a991645682 100644 --- a/test/optimizer/test_modules.py +++ b/test/optimizer/test_modules.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for OptiModule subclasses. This script implements unit tests that are automatically run diff --git a/test/optimizer/test_noise.py b/test/optimizer/test_noise.py index e8247e4ebb6a0f620bbf27f7c8ff80b796b7680c..7b40ff3811c2bf97723c8127fed859c5fb2955b6 100644 --- a/test/optimizer/test_noise.py +++ b/test/optimizer/test_noise.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Functional tests for NoiseModule subclasses. * Test that a given seed returns the same thing twice. diff --git a/test/optimizer/test_optimizer.py b/test/optimizer/test_optimizer.py index b4fe63fa592cbda0f3ebcbf716b220d1d396a752..1229c4c5dc841c06eb2440a65d5db4961367918b 100644 --- a/test/optimizer/test_optimizer.py +++ b/test/optimizer/test_optimizer.py @@ -1,9 +1,24 @@ # coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # type: ignore # mock objects everywhere """Unit tests for `declearn.optimizer.Optimizer`.""" -from typing import Any, Dict, Tuple +from typing import Any, ClassVar, Dict, Tuple from unittest import mock from uuid import uuid4 @@ -19,7 +34,7 @@ from declearn.test_utils import assert_json_serializable_dict class MockOptiModule(OptiModule): """Type-registered mock OptiModule subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__() @@ -35,7 +50,7 @@ class MockOptiModule(OptiModule): class MockRegularizer(Regularizer): """Type-registered mock Regularizer subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__() diff --git a/test/optimizer/test_regularizers.py b/test/optimizer/test_regularizers.py index 1e44ee44e8c2fa976fbec1f0d9e929164412c8ba..a51e8f2a034148466ab8f334b3481e5be3bbe038 100644 --- a/test/optimizer/test_regularizers.py +++ b/test/optimizer/test_regularizers.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for Regularizer subclasses. This script implements unit tests that are automatically run diff --git a/test/optimizer/test_scaffold.py b/test/optimizer/test_scaffold.py index 6515ef7372d14586877a70dbb59c4427544d135f..a7ae284b1c2f2be675a81d88d1895628a78be883 100644 --- a/test/optimizer/test_scaffold.py +++ b/test/optimizer/test_scaffold.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for Scaffold OptiModule subclasses.""" diff --git a/test/test_main.py b/test/test_main.py index be8ac62afd62bd3bfbc8b0f07c9eaf1bf97ffa05..f11ee0efcf411aa17d2526f3f9a287f35ea0bed6 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,20 +1,29 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Declearn demonstration / testing code.""" import tempfile import warnings -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional import numpy as np import pytest -with warnings.catch_warnings(): # silence tensorflow import-time warnings - warnings.simplefilter("ignore") - import tensorflow as tf # type: ignore -import torch -from typing_extensions import Literal # future: import from typing (Py>=3.8) - from declearn.communication import ( build_client, build_server, @@ -24,11 +33,26 @@ from declearn.communication.api import NetworkClient, NetworkServer from declearn.dataset import InMemoryDataset from declearn.model.api import Model from declearn.model.sklearn import SklearnSGDModel -from declearn.model.tensorflow import TensorflowModel -from declearn.model.torch import TorchModel from declearn.main import FederatedClient, FederatedServer from declearn.test_utils import run_as_processes +# Select the subset of tests to run, based on framework availability. +# Note: TensorFlow and Torch (-related) imports are delayed due to this. +# pylint: disable=ungrouped-imports +FRAMEWORKS = ["Sksgd", "Tflow", "Torch"] +try: + import tensorflow as tf # type: ignore +except ModuleNotFoundError: + FRAMEWORKS.remove("Tflow") +else: + from declearn.model.tensorflow import TensorflowModel +try: + import torch +except ModuleNotFoundError: + FRAMEWORKS.remove("Torch") +else: + from declearn.model.torch import TorchModel + class DeclearnTestCase: """Test-case for the "main" federated learning orchestrating classes.""" @@ -74,7 +98,7 @@ class DeclearnTestCase: def _build_tflow_model( self, - ) -> TensorflowModel: + ) -> Model: """Return a TensorflowModel suitable for the learning task.""" if self.kind == "Reg": output_layer = tf.keras.layers.Dense(1) @@ -98,8 +122,9 @@ class DeclearnTestCase: def _build_torch_model( self, - ) -> TorchModel: + ) -> Model: """Return a TorchModel suitable for the learning task.""" + # Build the model and return it. stack = [ torch.nn.Linear(32, 32), torch.nn.ReLU(), @@ -246,7 +271,7 @@ def run_test_case( @pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"]) -@pytest.mark.parametrize("framework", ["Sksgd", "Tflow", "Torch"]) +@pytest.mark.parametrize("framework", FRAMEWORKS) @pytest.mark.parametrize("kind", ["Reg", "Bin", "Clf"]) @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization") def test_declearn( diff --git a/test/utils/test_json.py b/test/utils/test_json.py index b17e2d9c46d49aa9c6523944378cae3c7c835c7a..05cfb27b57e59b9dc35b322957ad64cc8fe8aaff 100644 --- a/test/utils/test_json.py +++ b/test/utils/test_json.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.utils._json` tools.""" import json @@ -46,6 +61,7 @@ def test_add_json_support() -> None: not that the associated mechanics perform well. These are tested in `test_json_pack` and `test_json_unpack_known`. """ + # Declare a second, empty custom type for this test only. class OtherType: # pylint: disable=all pass @@ -66,6 +82,7 @@ def test_add_json_support() -> None: def test_json_pack() -> None: """Unit tests for `json_pack` with custom-specified objects.""" + # Define a subtype of CustomType (to ensure it is not supported). class SubType(CustomType): # pylint: disable=all pass diff --git a/test/utils/test_register.py b/test/utils/test_register.py index 302a04ebdbd0c0d425f52d01700aac62b86b5805..2cf5085c0062835c0d6a8b6df169611275fcbffe 100644 --- a/test/utils/test_register.py +++ b/test/utils/test_register.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for 'declearn.utils._register' tools.""" import time @@ -25,6 +40,7 @@ def test_create_types_registry() -> None: def test_register_type() -> None: """Unit tests for 'register_type' using valid instructions.""" + # Define mock custom classes. class BaseClass: # pylint: disable=all pass @@ -38,6 +54,7 @@ def test_register_type() -> None: assert register_type(BaseClass, name="base", group=group) is BaseClass # Register ChildClass. assert register_type(ChildClass, name="child", group=group) is ChildClass + # Register another BaseClass-inheriting class using decorator syntax. @register_type(name="other", group=group) class OtherChild(BaseClass): @@ -46,6 +63,7 @@ def test_register_type() -> None: def test_register_type_fails() -> None: """Unit tests for 'register_type' using invalid instructions.""" + # Define mock custom classes. class BaseClass: # pylint: disable=all pass @@ -69,6 +87,7 @@ def test_register_type_fails() -> None: def test_access_registered() -> None: """Unit tests for 'access_registered'.""" + # Define a mock custom class. class Class: # pylint: disable=all pass @@ -90,6 +109,7 @@ def test_access_registered() -> None: def test_access_registeration_info() -> None: """Unit tests for 'access_registration_info'.""" + # Define a pair of mock custom class. class Class_1: # pylint: disable=all pass @@ -116,6 +136,7 @@ def test_access_registeration_info() -> None: def test_access_types_mapping() -> None: """Unit tests for 'access_types_mapping'.""" group = f"test_{time.time_ns()}" + # Define mock custom type-registered classes. @register_type(name="base", group=group) @create_types_registry(name=group) diff --git a/test/utils/test_serialize.py b/test/utils/test_serialize.py index 468ecca945e79995dcb0890519cead56fa40c2e1..9874e7fd573cc9bc50b0fc060768a5d50775d5d5 100644 --- a/test/utils/test_serialize.py +++ b/test/utils/test_serialize.py @@ -1,5 +1,20 @@ # coding: utf-8 +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for `declearn.utils._serialize` tools. Note: some of these tests require `declearn.utils._register` tools @@ -44,6 +59,7 @@ class MockClass: @pytest.fixture(name="registered_class") def fixture_registered_class() -> Tuple[Type[MockClass], str]: """Provide with a type-registered MockClass subclass.""" + # Declare a subtype to avoid side effects between tests. class SubClass(MockClass): # pylint: disable=all pass