From 7a37aa52cb5b78d34800551b29cc5d8381487885 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 25 Nov 2022 16:28:07 +0100
Subject: [PATCH] Refactor argparse code from the heart-uci example.

---
 declearn/test_utils/__init__.py  |   1 +
 declearn/test_utils/_argparse.py | 154 +++++++++++++++++++++++++++++++
 examples/heart-uci/client.py     |  35 ++-----
 examples/heart-uci/server.py     |  55 +++--------
 4 files changed, 177 insertions(+), 68 deletions(-)
 create mode 100644 declearn/test_utils/_argparse.py

diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py
index 08a136a3..65afb8b1 100644
--- a/declearn/test_utils/__init__.py
+++ b/declearn/test_utils/__init__.py
@@ -17,6 +17,7 @@
 
 """Collection of utils for running tests and examples around declearn."""
 
+from ._argparse import setup_client_argparse, setup_server_argparse
 from ._assertions import (
     assert_dict_equal,
     assert_list_equal,
diff --git a/declearn/test_utils/_argparse.py b/declearn/test_utils/_argparse.py
new file mode 100644
index 00000000..b59f99ff
--- /dev/null
+++ b/declearn/test_utils/_argparse.py
@@ -0,0 +1,154 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to set up command-line argument parsers for declearn examples."""
+
+import argparse
+from typing import Optional
+
+
+__all__ = [
+    "setup_client_argparse",
+    "setup_server_argparse",
+]
+
+
+def setup_client_argparse(
+    usage: Optional[str] = None,
+    default_uri: str = "wss://localhost:8765",
+    default_ptcl: str = "websockets",
+    default_cert: str = "./ca-cert.pem",
+) -> argparse.ArgumentParser:
+    """Set up an ArgumentParser to be used in a client-side script.
+
+    Arguments
+    ---------
+    usage: str or None, default=None
+        Optional usage string to add to the ArgumentParser.
+    default_uri: str, default="wss://localhost:8765"
+        Default value for the 'uri' argument.
+    default_ptcl: str, default="websockets"
+        Default value for the 'protocol' argument.
+    default_cert: str, default="./ca-cert.pem"
+        Default value for the 'certificate' argument.
+
+    Returns
+    -------
+    parser: argparse.ArgumentParser
+        ArgumentParser with pre-set optional arguments required
+        to configure network communications on the client side.
+    """
+    parser = argparse.ArgumentParser(
+        usage=usage,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--uri",
+        dest="uri",
+        type=str,
+        help="server URI to which to connect",
+        default=default_uri,
+    )
+    parser.add_argument(
+        "--protocol",
+        dest="protocol",
+        type=str,
+        help="name of the communication protocol to use",
+        default=default_ptcl,
+    )
+    parser.add_argument(
+        "--cert",
+        dest="certificate",
+        type=str,
+        help="path to the client-side ssl certificate authority file",
+        default=default_cert,
+    )
+    return parser
+
+
+def setup_server_argparse(
+    usage: Optional[str] = None,
+    default_host: str = "localhost",
+    default_port: int = 8765,
+    default_ptcl: str = "websockets",
+    default_cert: str = "./server-cert.pem",
+    default_pkey: str = "./server-pkey.pem",
+) -> argparse.ArgumentParser:
+    """Set up an ArgumentParser to be used in a server-side script.
+
+    Arguments
+    ---------
+    usage: str or None, default=None
+        Optional usage string to add to the ArgumentParser.
+    default_host: str, default="localhost"
+        Default value for the 'host' argument.
+    default_port: int, default=8765
+        Default value for the 'port' argument.
+    default_ptcl: str, default="websockets"
+        Default value for the 'protocol' argument.
+    default_cert: str, default="./server-cert.pem"
+        Default value for the 'certificate' argument.
+    default_pkey: str, default="./server-pkey.pem"
+        Default value for the 'private_key' argument.
+
+    Returns
+    -------
+    parser: argparse.ArgumentParser
+        ArgumentParser with pre-set optional arguments required
+        to configure network communications on the server side.
+    """
+    # arguments serve modularity; pylint: disable=too-many-arguments
+    parser = argparse.ArgumentParser(
+        usage=usage,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--host",
+        dest="host",
+        type=str,
+        help="hostname or IP address on which to serve",
+        default=default_host,
+    )
+    parser.add_argument(
+        "--port",
+        dest="port",
+        type=int,
+        help="communication port on which to serve",
+        default=default_port,
+    )
+    parser.add_argument(
+        "--protocol",
+        dest="protocol",
+        type=str,
+        help="name of the communication protocol to use",
+        default=default_ptcl,
+    )
+    parser.add_argument(
+        "--cert",
+        dest="certificate",
+        type=str,
+        help="path to the server-side ssl certificate",
+        default=default_cert,
+    )
+    parser.add_argument(
+        "--pkey",
+        dest="private_key",
+        type=str,
+        help="path to the server-side ssl private key",
+        default=default_pkey,
+    )
+    return parser
diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py
index 8e30f6d6..b33fcd05 100644
--- a/examples/heart-uci/client.py
+++ b/examples/heart-uci/client.py
@@ -17,7 +17,6 @@
 
 """Script to run a federated client on the heart-disease example."""
 
-import argparse
 import os
 import sys
 
@@ -26,11 +25,15 @@ import pandas as pd  # type: ignore
 from declearn.communication import NetworkClientConfig
 from declearn.dataset import InMemoryDataset
 from declearn.main import FederatedClient
+from declearn.test_utils import setup_client_argparse
 
 FILEDIR = os.path.dirname(os.path.abspath(__file__))
 # Perform local imports.
 sys.path.append(FILEDIR)
-from data import get_data  # pylint: disable=wrong-import-order
+# pylint: disable=wrong-import-order, wrong-import-position
+from data import get_data
+# pylint: enable=wrong-import-order, wrong-import-position
+sys.path.pop()
 
 
 def run_client(
@@ -102,34 +105,16 @@ def run_client(
 # Called when the script is called directly (using `python client.py`).
 if __name__ == "__main__":
     # Parse command-line arguments.
-    parser = argparse.ArgumentParser()
+    parser = setup_client_argparse(
+        usage="Start a client providing a UCI Heart-Disease Dataset shard.",
+        default_cert=os.path.join(FILEDIR, "ca-cert.pem"),
+    )
     parser.add_argument(
         "name",
         type=str,
         help="name of your client",
         choices=["cleveland", "hungarian", "switzerland", "va"],
     )
-    parser.add_argument(
-        "--cert",
-        dest="cert_path",
-        type=str,
-        help="path to the client-side ssl certification",
-        default=os.path.join(FILEDIR, "ca-cert.pem"),
-    )
-    parser.add_argument(
-        "--protocol",
-        dest="protocol",
-        type=str,
-        help="name of the communication protocol to use",
-        default="websockets",
-    )
-    parser.add_argument(
-        "--uri",
-        dest="uri",
-        type=str,
-        help="server URI to which to connect",
-        default="wss://localhost:8765",
-    )
     args = parser.parse_args()
     # Run the client routine.
-    run_client(args.name, args.cert_path, args.protocol, args.uri)
+    run_client(args.name, args.certificate, args.protocol, args.uri)
diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py
index 56de202f..6ef73d29 100644
--- a/examples/heart-uci/server.py
+++ b/examples/heart-uci/server.py
@@ -17,21 +17,21 @@
 
 """Script to run a federated server on the heart-disease example."""
 
-import argparse
 import os
 
 from declearn.communication import NetworkServerConfig
 from declearn.main import FederatedServer
 from declearn.main.config import FLOptimConfig, FLRunConfig
 from declearn.model.sklearn import SklearnSGDModel
+from declearn.test_utils import setup_server_argparse
 
 FILEDIR = os.path.dirname(os.path.abspath(__file__))
 
 
 def run_server(
     nb_clients: int,
-    sv_cert: str,
-    sv_pkey: str,
+    certificate: str,
+    private_key: str,
     protocol: str = "websockets",
     host: str = "localhost",
     port: int = 8765,
@@ -42,9 +42,9 @@ def run_server(
     ---------
     nb_clients: int
         Exact number of clients used in this example.
-    sv_cert: str
+    certificate: str
         Path to the (self-signed) SSL certificate to use.
-    sv_pkey: str
+    private_key: str
         Path to the associated private-key to use.
     protocol: str, default="websockets"
         Name of the communication protocol to use.
@@ -100,8 +100,8 @@ def run_server(
         protocol=protocol,
         host=host,
         port=port,
-        certificate=sv_cert,
-        private_key=sv_pkey,
+        certificate=certificate,
+        private_key=private_key,
     )
 
     # (4) Instantiate and run a FederatedServer.
@@ -132,48 +132,17 @@ def run_server(
 # Called when the script is called directly (using `python server.py`).
 if __name__ == "__main__":
     # Parse command-line arguments.
-    parser = argparse.ArgumentParser()
+    parser = setup_server_argparse(
+        usage="Start a server to train a logistic regression model.",
+        default_cert=os.path.join(FILEDIR, "server-cert.pem"),
+        default_pkey=os.path.join(FILEDIR, "server-pkey.pem"),
+    )
     parser.add_argument(
         "nb_clients",
         type=int,
         help="number of clients",
         choices=[1, 2, 3, 4],
     )
-    parser.add_argument(
-        "--cert",
-        dest="sv_cert",
-        type=str,
-        help="path to the server-side ssl certificate",
-        default=os.path.join(FILEDIR, "server-cert.pem"),
-    )
-    parser.add_argument(
-        "--pkey",
-        dest="sv_pkey",
-        type=str,
-        help="path to the server-side ssl private key",
-        default=os.path.join(FILEDIR, "server-pkey.pem"),
-    )
-    parser.add_argument(
-        "--protocol",
-        dest="protocol",
-        type=str,
-        help="name of the communication protocol to use",
-        default="websockets",
-    )
-    parser.add_argument(
-        "--host",
-        dest="host",
-        type=str,
-        help="hostname or IP address on which to serve",
-        default="localhost",
-    )
-    parser.add_argument(
-        "--port",
-        dest="port",
-        type=int,
-        help="communication port on which to serve",
-        default=8765,
-    )
     args = parser.parse_args()
     # Run the server routine.
     run_server(**args.__dict__)
-- 
GitLab