From 7a37aa52cb5b78d34800551b29cc5d8381487885 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 25 Nov 2022 16:28:07 +0100 Subject: [PATCH] Refactor argparse code from the heart-uci example. --- declearn/test_utils/__init__.py | 1 + declearn/test_utils/_argparse.py | 154 +++++++++++++++++++++++++++++++ examples/heart-uci/client.py | 35 ++----- examples/heart-uci/server.py | 55 +++-------- 4 files changed, 177 insertions(+), 68 deletions(-) create mode 100644 declearn/test_utils/_argparse.py diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index 08a136a3..65afb8b1 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -17,6 +17,7 @@ """Collection of utils for running tests and examples around declearn.""" +from ._argparse import setup_client_argparse, setup_server_argparse from ._assertions import ( assert_dict_equal, assert_list_equal, diff --git a/declearn/test_utils/_argparse.py b/declearn/test_utils/_argparse.py new file mode 100644 index 00000000..b59f99ff --- /dev/null +++ b/declearn/test_utils/_argparse.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to set up command-line argument parsers for declearn examples.""" + +import argparse +from typing import Optional + + +__all__ = [ + "setup_client_argparse", + "setup_server_argparse", +] + + +def setup_client_argparse( + usage: Optional[str] = None, + default_uri: str = "wss://localhost:8765", + default_ptcl: str = "websockets", + default_cert: str = "./ca-cert.pem", +) -> argparse.ArgumentParser: + """Set up an ArgumentParser to be used in a client-side script. + + Arguments + --------- + usage: str or None, default=None + Optional usage string to add to the ArgumentParser. + default_uri: str, default="wss://localhost:8765" + Default value for the 'uri' argument. + default_ptcl: str, default="websockets" + Default value for the 'protocol' argument. + default_cert: str, default="./ca-cert.pem" + Default value for the 'certificate' argument. + + Returns + ------- + parser: argparse.ArgumentParser + ArgumentParser with pre-set optional arguments required + to configure network communications on the client side. + """ + parser = argparse.ArgumentParser( + usage=usage, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--uri", + dest="uri", + type=str, + help="server URI to which to connect", + default=default_uri, + ) + parser.add_argument( + "--protocol", + dest="protocol", + type=str, + help="name of the communication protocol to use", + default=default_ptcl, + ) + parser.add_argument( + "--cert", + dest="certificate", + type=str, + help="path to the client-side ssl certificate authority file", + default=default_cert, + ) + return parser + + +def setup_server_argparse( + usage: Optional[str] = None, + default_host: str = "localhost", + default_port: int = 8765, + default_ptcl: str = "websockets", + default_cert: str = "./server-cert.pem", + default_pkey: str = "./server-pkey.pem", +) -> argparse.ArgumentParser: + """Set up an ArgumentParser to be used in a server-side script. + + Arguments + --------- + usage: str or None, default=None + Optional usage string to add to the ArgumentParser. + default_host: str, default="localhost" + Default value for the 'host' argument. + default_port: int, default=8765 + Default value for the 'port' argument. + default_ptcl: str, default="websockets" + Default value for the 'protocol' argument. + default_cert: str, default="./server-cert.pem" + Default value for the 'certificate' argument. + default_pkey: str, default="./server-pkey.pem" + Default value for the 'private_key' argument. + + Returns + ------- + parser: argparse.ArgumentParser + ArgumentParser with pre-set optional arguments required + to configure network communications on the server side. + """ + # arguments serve modularity; pylint: disable=too-many-arguments + parser = argparse.ArgumentParser( + usage=usage, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--host", + dest="host", + type=str, + help="hostname or IP address on which to serve", + default=default_host, + ) + parser.add_argument( + "--port", + dest="port", + type=int, + help="communication port on which to serve", + default=default_port, + ) + parser.add_argument( + "--protocol", + dest="protocol", + type=str, + help="name of the communication protocol to use", + default=default_ptcl, + ) + parser.add_argument( + "--cert", + dest="certificate", + type=str, + help="path to the server-side ssl certificate", + default=default_cert, + ) + parser.add_argument( + "--pkey", + dest="private_key", + type=str, + help="path to the server-side ssl private key", + default=default_pkey, + ) + return parser diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py index 8e30f6d6..b33fcd05 100644 --- a/examples/heart-uci/client.py +++ b/examples/heart-uci/client.py @@ -17,7 +17,6 @@ """Script to run a federated client on the heart-disease example.""" -import argparse import os import sys @@ -26,11 +25,15 @@ import pandas as pd # type: ignore from declearn.communication import NetworkClientConfig from declearn.dataset import InMemoryDataset from declearn.main import FederatedClient +from declearn.test_utils import setup_client_argparse FILEDIR = os.path.dirname(os.path.abspath(__file__)) # Perform local imports. sys.path.append(FILEDIR) -from data import get_data # pylint: disable=wrong-import-order +# pylint: disable=wrong-import-order, wrong-import-position +from data import get_data +# pylint: enable=wrong-import-order, wrong-import-position +sys.path.pop() def run_client( @@ -102,34 +105,16 @@ def run_client( # Called when the script is called directly (using `python client.py`). if __name__ == "__main__": # Parse command-line arguments. - parser = argparse.ArgumentParser() + parser = setup_client_argparse( + usage="Start a client providing a UCI Heart-Disease Dataset shard.", + default_cert=os.path.join(FILEDIR, "ca-cert.pem"), + ) parser.add_argument( "name", type=str, help="name of your client", choices=["cleveland", "hungarian", "switzerland", "va"], ) - parser.add_argument( - "--cert", - dest="cert_path", - type=str, - help="path to the client-side ssl certification", - default=os.path.join(FILEDIR, "ca-cert.pem"), - ) - parser.add_argument( - "--protocol", - dest="protocol", - type=str, - help="name of the communication protocol to use", - default="websockets", - ) - parser.add_argument( - "--uri", - dest="uri", - type=str, - help="server URI to which to connect", - default="wss://localhost:8765", - ) args = parser.parse_args() # Run the client routine. - run_client(args.name, args.cert_path, args.protocol, args.uri) + run_client(args.name, args.certificate, args.protocol, args.uri) diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py index 56de202f..6ef73d29 100644 --- a/examples/heart-uci/server.py +++ b/examples/heart-uci/server.py @@ -17,21 +17,21 @@ """Script to run a federated server on the heart-disease example.""" -import argparse import os from declearn.communication import NetworkServerConfig from declearn.main import FederatedServer from declearn.main.config import FLOptimConfig, FLRunConfig from declearn.model.sklearn import SklearnSGDModel +from declearn.test_utils import setup_server_argparse FILEDIR = os.path.dirname(os.path.abspath(__file__)) def run_server( nb_clients: int, - sv_cert: str, - sv_pkey: str, + certificate: str, + private_key: str, protocol: str = "websockets", host: str = "localhost", port: int = 8765, @@ -42,9 +42,9 @@ def run_server( --------- nb_clients: int Exact number of clients used in this example. - sv_cert: str + certificate: str Path to the (self-signed) SSL certificate to use. - sv_pkey: str + private_key: str Path to the associated private-key to use. protocol: str, default="websockets" Name of the communication protocol to use. @@ -100,8 +100,8 @@ def run_server( protocol=protocol, host=host, port=port, - certificate=sv_cert, - private_key=sv_pkey, + certificate=certificate, + private_key=private_key, ) # (4) Instantiate and run a FederatedServer. @@ -132,48 +132,17 @@ def run_server( # Called when the script is called directly (using `python server.py`). if __name__ == "__main__": # Parse command-line arguments. - parser = argparse.ArgumentParser() + parser = setup_server_argparse( + usage="Start a server to train a logistic regression model.", + default_cert=os.path.join(FILEDIR, "server-cert.pem"), + default_pkey=os.path.join(FILEDIR, "server-pkey.pem"), + ) parser.add_argument( "nb_clients", type=int, help="number of clients", choices=[1, 2, 3, 4], ) - parser.add_argument( - "--cert", - dest="sv_cert", - type=str, - help="path to the server-side ssl certificate", - default=os.path.join(FILEDIR, "server-cert.pem"), - ) - parser.add_argument( - "--pkey", - dest="sv_pkey", - type=str, - help="path to the server-side ssl private key", - default=os.path.join(FILEDIR, "server-pkey.pem"), - ) - parser.add_argument( - "--protocol", - dest="protocol", - type=str, - help="name of the communication protocol to use", - default="websockets", - ) - parser.add_argument( - "--host", - dest="host", - type=str, - help="hostname or IP address on which to serve", - default="localhost", - ) - parser.add_argument( - "--port", - dest="port", - type=int, - help="communication port on which to serve", - default=8765, - ) args = parser.parse_args() # Run the server routine. run_server(**args.__dict__) -- GitLab