diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index d34929f152633415ac78db3e11017101c596caf3..dd302fabb8557cb8460bd41acf579315fcdf43ed 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -89,6 +89,15 @@ Utils to set up and configure loggers: * [LOGGING_LEVEL_MAJOR][declearn.utils.LOGGING_LEVEL_MAJOR]: Custom "MAJOR" severity level, between stdlib "INFO" and "WARNING". +Requirement-specification utils +------------------------------- +Utils to specify, verify and import required third-party libraries: + +* [Requirement][declearn.utils.Requirement]: + Interface to specify, version-check and import python requirements. +* [RequirementError][declearn.utils.RequirementError]: + Custom Exception class to denote third-party requirement issues. + Miscellaneous ------------- @@ -137,6 +146,10 @@ from ._register import ( create_types_registry, register_type, ) +from ._requirement import ( + Requirement, + RequirementError, +) from ._serialize import ( ObjectConfig, deserialize_object, diff --git a/declearn/utils/_requirement.py b/declearn/utils/_requirement.py new file mode 100644 index 0000000000000000000000000000000000000000..39e6a63f4d76ed1a19e988e2ada69ce6e43a9d32 --- /dev/null +++ b/declearn/utils/_requirement.py @@ -0,0 +1,316 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to specify, verify and import required third-party libraries.""" + +import importlib.metadata +import re +from types import ModuleType +from typing import Iterator, List, Optional, Tuple + +import packaging.requirements + + +__all__ = [ + "Requirement", + "RequirementError", +] + + +class RequirementError(Exception): + """Custom Exception class to denote third-party requirement issues.""" + + +class Requirement: + """Interface to specify, version-check and import python requirements. + + Usage + ----- + ``` + >>> # Equivalent of `import sklearn`, with scikit-learn >= 1.0. + >>> req_sklearn = Requirement( + ... "scikit-learn >=1.0", import_name="sklearn" + ... ) + >>> sklearn = req_sklearn.import_module() + >>> # Equivalent of `import numpy as np` with any version of numpy. + >>> req_numpy = Requirement("numpy") # any version + >>> np = req_numpy.import_module() + >>> # Equivalent to `import declearn` & `import declearn.model.torch`, + >>> # after verifying that "torch2" extras were installed with declearn. + >>> req_declearn = Requirement( + ... "declearn[torch2] ~=2.0", + ... submodules=["declearn.model.torch"], + ... ) + >>> declearn = req_declearn.import_module() + ``` + """ + + def __init__( + self, + requirement: str, + import_name: Optional[str] = None, + submodules: Optional[List[str]] = None, + ) -> None: + """Specify a requirement for a third-party module. + + Parameters + ---------- + requirement: + Dependency string of the required third-party package. + This may comprise extra and/or version specifiers. + import_name: + Optional import name of the package. + Only required for packages that are installed under another + name as their import one (e.g. `scklearn` the installation + name of which is `scikit-learn`). + submodules: + Optional list of submodules to import with the main one. + This should be used for submodules that are not imported + by default (e.g. `linear_model` for `sklearn`). + """ + self.requirement = requirement + self.import_name = import_name or self.install_name + self.submodules = [] if submodules is None else sorted(submodules) + + @property + def install_name(self) -> str: + """Installation name of the required dependency.""" + return packaging.requirements.Requirement(self.requirement).name + + @property + def pretty_name(self) -> str: + """String formatting this requirement in a user-friendly way.""" + string = f"'{self.import_name}'" + if self.submodules: + string += f", plus submodules {self.submodules}" + return string + + def __repr__(self) -> str: + return f"Requirement for package {self.pretty_name}." + + def raise_if_unavailable(self) -> None: + """Raise a `RequirementError` if the requirement is not satisfied.""" + verify_required_dependency_availability(self.requirement) + + def check_if_available(self) -> bool: + """Verify if the specified requirement is satisfied. + + Returns + ------- + available: + Bool indicating whether the requirement is available. + """ + try: + self.raise_if_unavailable() + except RequirementError: + return False + return True + + def import_module(self) -> ModuleType: + """Import the required package and return it. + + Also trigger the import of specified submodules. + + Returns + ------- + module: + The main required module. + + Raises + ------ + RequirementError + If the requirement is not available, or if something goes + wrong at import time. + """ + self.raise_if_unavailable() + self._verify_import_name() + self._prompt_user_agreement() + module = self._import_main_module() + self._import_submodules() + return module + + def _prompt_user_agreement(self) -> None: + """Prompt user to authorize module loading. + + Raise a `RequirementError` if the import is not authorized. + """ + prompt = ( + "declearn wants to import the following dependency: " + + f"{self.pretty_name}.\nDo you allow it? (y/N)" + ) + if not input(prompt).strip(" ").lower().startswith("y"): + raise RequirementError( + f"User disallowed the import of dependency {self.requirement}." + ) + + def _verify_import_name(self) -> None: + """Raise a `RequirementError` if the import name is incorrect. + + This method requires the targetted dependency to be installed. + It verifies that the specified import_name (if any) is linked + to the installation one, and raises if it is not. + """ + if self.import_name == self.install_name: + return + dep_name = self.install_name + pkg_name = self.import_name + matched = importlib.metadata.packages_distributions().get(pkg_name, []) + if not dep_name in matched: + raise RequirementError( + f"Specified import name '{pkg_name}' is not associated with " + f"the installed requirement's name '{dep_name}'." + ) + + def _import_main_module(self) -> ModuleType: + """Import the main required module and return it.""" + try: + return importlib.import_module(self.import_name) + except (ImportError, ModuleNotFoundError) as exc: + raise RequirementError( + f"Requirement '{self.requirement}' was verified to be " + "available, but failed to be imported under the name " + f"'{self.import_name}'." + ) from exc + + def _import_submodules(self) -> None: + """Import required submodules (do not return them).""" + prefix = f"{self.import_name}." + for name in self.submodules: + if not name.startswith(prefix): + name = prefix + name + try: + importlib.import_module(name) + except (ImportError, ModuleNotFoundError) as exc: + raise RequirementError( + f"Failed to import submodule '{name}' after successfully " + f"importing module '{self.import_name}'." + ) from exc + + +def verify_required_dependency_availability( + requirement_string: str, +) -> None: + """Raise a RequirementError is a given dependency is not satisfied. + + Parse the input dependency string, and verify that + (a) the dependency is installed, with a compatible version; + (b) its required extra dependencies are also installed. + + Parameters + ---------- + requirement_string: + Dependency string for the required third-party python package. + This should follow the formatting rules specified in python's + packaging system (notably used with pip, conda, etc.). + + Raises + ------ + RequirementError + If the package is missing, has an incompatible version, or is + missing any of its required extra dependencies. + """ + requirement = packaging.requirements.Requirement(requirement_string) + # Raise if the package is not installed (in any version). + distribution = get_distribution(requirement) + # Raise if the package is installed with an incompatible version. + if not requirement.specifier.contains(distribution.version): + raise RequirementError( + f"Package '{requirement.name}' is installed in version " + f"'{distribution.version}', which is incompatible with " + f"required version '{requirement.specifier}'." + ) + # Raise if the required extra dependencies are not installed. + verify_extra_dependencies(requirement, distribution) + + +def get_distribution( + requirement: packaging.requirements.Requirement, +) -> importlib.metadata.Distribution: + """Fetch metadata on a required third-party library. + + Raise a `RequirementError` if the requirement is not installed. + """ + try: + return importlib.metadata.distribution(requirement.name) + except importlib.metadata.PackageNotFoundError as exc: + raise RequirementError( + f"Package '{requirement.name}' is not installed." + ) from exc + + +def verify_extra_dependencies( + requirement: packaging.requirements.Requirement, + distribution: importlib.metadata.Distribution, +) -> None: + """Raise a `RequirementError` if required extra dependencies are missing. + + Parameters + ---------- + requirement: + Requirement object wrapping information on a required package. + distribution: + Distribution object wrapping information on the installed + version of the required package. + + Raises + ------ + RequirementError + If any of the extra-flag-based dependencies of `distribution` that + is required by `requirement` is not installed in a proper version. + """ + iter_deps = yield_required_extra_dependencies(requirement, distribution) + for dependency, extra_name in iter_deps: + try: + verify_required_dependency_availability(dependency) + except RequirementError as exc: + raise RequirementError( + f"Package '{requirement.name}' is missing a required extra " + f"dependency: '{dependency}', from extra '{extra_name}'." + ) from exc + + +def yield_required_extra_dependencies( + requirement: packaging.requirements.Requirement, + distribution: importlib.metadata.Distribution, +) -> Iterator[Tuple[str, str]]: + """Yield dependency strings for required extra dependencies. + + Parameters + ---------- + requirement: + Requirement object wrapping information on a required package. + distribution: + Distribution object wrapping information on the installed + version of the required package. + + Yields + ------ + dependency: + Dependency string for an extra-flag-based dependency of + the `distribution` that is required by `requirement`. + extra_name: + Name of the extra flag that made this dependency required. + """ + # Case when there are no required or requirable extras. + if (not requirement.extras) or (not distribution.requires): + return + # Iteratively parse possible extra dependencies and filter required ones. + for dependency in distribution.requires: + extras = re.findall('extra == "(.*?)"', dependency) + if requirement.extras.intersection(extras): + dep_string = dependency.split(";", 1)[0] + yield (dep_string, extras[0]) diff --git a/pyproject.toml b/pyproject.toml index 68f9dcb10c96f92ee50eebbd85b4834ca38820f0..44dc7268ee42f7497ba0ae389f1745f757fbcfe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "cryptography >= 35.0", "fire ~= 0.4", "gmpy2 ~= 2.1", + "packaging >= 22.0", "pandas >= 1.2, < 3.0", "requests ~= 2.18", "scikit-learn ~= 1.0", diff --git a/test/utils/test_requirement.py b/test/utils/test_requirement.py new file mode 100644 index 0000000000000000000000000000000000000000..3a47976425e15399abc7ec41e403c584c995857d --- /dev/null +++ b/test/utils/test_requirement.py @@ -0,0 +1,138 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from unittest import mock + +import numpy +import pytest +import sklearn # type: ignore + +from declearn.utils import Requirement, RequirementError + + +class TestRequirement: + """Unit tests for the 'Requirement' util.""" + + def test_requirement_repr(self) -> None: + """Docstring.""" + name = str(uuid.uuid4()) + subn = str(uuid.uuid4()) + req = Requirement(name, submodules=[subn]) + assert name in repr(req) + assert subn in repr(req) + + def test_require_missing_dependency(self) -> None: + """Docstring.""" + req = Requirement(str(uuid.uuid4())) + assert not req.check_if_available() + + def test_require_dependency_without_specifiers(self) -> None: + """Docstring.""" + req = Requirement("scikit-learn") + assert req.check_if_available() + + def test_require_dependency_with_version_specifier(self) -> None: + """Docstring.""" + req = Requirement("scikit-learn >=1.0") + assert req.check_if_available() + + def test_require_dependency_with_unmatched_version_specifier(self) -> None: + """Docstring.""" + req = Requirement("scikit-learn <1.0") + assert not req.check_if_available() + + def test_require_dependency_with_extra_specifier(self) -> None: + """Docstring.""" + req = Requirement("declearn[tests]") + assert req.check_if_available() + + def test_require_dependency_with_unmatched_extra_specifier(self) -> None: + """Docstring.""" + req = Requirement("declearn[torch1,torch2]") + assert not req.check_if_available() + + def test_import_module_numpy(self) -> None: + """Docstring.""" + req = Requirement("numpy") + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + module = req.import_module() + patch_input.assert_called_once() + assert module is numpy + + def test_import_module_sklearn_wrong_import_name(self) -> None: + """Docstring.""" + req = Requirement("scikit-learn", import_name=None) + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + with pytest.raises(RequirementError): + req.import_module() + patch_input.assert_called_once() + + def test_import_module_sklearn(self) -> None: + """Docstring.""" + req = Requirement("scikit-learn", import_name="sklearn") + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + module = req.import_module() + patch_input.assert_called_once() + assert module is sklearn + assert not hasattr(module, "cluster") + + def test_import_module_sklearn_with_submodule(self) -> None: + """Docstring.""" + req = Requirement( + "scikit-learn", import_name="sklearn", submodules=["cluster"] + ) + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + module = req.import_module() + patch_input.assert_called_once() + assert module is sklearn + assert hasattr(module, "cluster") + + def test_import_module_sklearn_with_misnamed_submodule(self) -> None: + """Docstring.""" + req = Requirement( + "scikit-learn", + import_name="sklearn", + submodules=[str(uuid.uuid4())], + ) + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + with pytest.raises(RequirementError): + req.import_module() + patch_input.assert_called_once() + + def test_import_module_blocked_by_user(self) -> None: + """Docstring.""" + req = Requirement("declearn") + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "n" + with pytest.raises(RequirementError): + req.import_module() + patch_input.assert_called_once() + + def test_import_module_malicious_name_blocked(self) -> None: + """Docstring.""" + req = Requirement("declearn", import_name="numpy") + with mock.patch("builtins.input") as patch_input: + patch_input.return_value = "y" + with pytest.raises(RequirementError): + req.import_module() + patch_input.assert_not_called()