From 28bf43bffb7ecdf411ac82ec2a46a6d3f948bb49 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 28 Jul 2020 16:02:46 +0200 Subject: [PATCH] Allow components to define package requirements --- spacy/errors.py | 1 + spacy/language.py | 8 ++++++++ spacy/tests/test_misc.py | 12 ++++++++++++ spacy/util.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 50 insertions(+) diff --git a/spacy/errors.py b/spacy/errors.py index df6f82757..5d26407ad 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -483,6 +483,7 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E952 = ("Invalid requirement specified by component '{name}': {req}") E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") E954 = ("The Tok2Vec listener did not receive a valid input.") E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") diff --git a/spacy/language.py b/spacy/language.py index cade90b24..26ac85a5d 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -203,6 +203,7 @@ class Language: self._meta.setdefault("url", "") self._meta.setdefault("license", "") self._meta.setdefault("spacy_git_version", GIT_VERSION) + self._meta.setdefault("requirements", []) self._meta["vectors"] = { "width": self.vocab.vectors_length, "vectors": len(self.vocab.vectors), @@ -210,6 +211,8 @@ class Language: "name": self.vocab.vectors.name, } self._meta["labels"] = self.pipe_labels + reqs = {p: self.get_pipe_meta(p).package_requirements for p in self.pipe_names} + self._meta["requirements"].extend(util.merge_pipe_requirements(reqs)) return self._meta @meta.setter @@ -356,6 +359,7 @@ class Language: retokenizes: bool = False, scores: Iterable[str] = tuple(), default_score_weights: Dict[str, float] = SimpleFrozenDict(), + package_requirements: Iterable[str] = tuple(), func: Optional[Callable] = None, ) -> Callable: """Register a new pipeline component factory. 
@pytest.mark.parametrize(
    "reqs,expected",
    [
        (
            [["a>=3.0,<2.0", "b[x]>2.1"], ["b<3.0"]],
            ["a<2.0,>=3.0", "b[x]<3.0,>2.1"],
        ),
        (
            [["a[foo]<1"], ["a[bar]>=5"]],
            ["a[bar,foo]<1,>=5"],
        ),
    ],
)
def test_merge_pipe_requirements(reqs, expected):
    """Check that per-component requirement lists are merged as expected.

    Each inner list in `reqs` is treated as the requirements of one
    component; the component names are the stringified list positions.
    """
    requirements_by_pipe = {}
    for index, component_reqs in enumerate(reqs):
        requirements_by_pipe[str(index)] = component_reqs
    merged = util.merge_pipe_requirements(requirements_by_pipe)
    assert merged == expected
def merge_pipe_requirements(pipe_requirements: Dict[str, Iterable[str]]) -> List[str]:
    """Merge package requirements specified by pipeline components.

    Requirements that refer to the same package are combined into a single
    entry: their version specifiers are intersected and their extras are
    unioned. Package names are compared in normalized form, so "Foo_Bar" and
    "foo-bar" count as the same package; the spelling of the first occurrence
    is the one that ends up in the output. No check is made that the combined
    specifiers can actually be satisfied, so in theory this could create
    requirements that can never be met, like >=3.0,<3.0.

    pipe_requirements (Dict[str, Iterable[str]]): Requirement strings specified
        by the components, keyed by component name.
    RETURNS (List[str]): The combined requirements for the pipeline.
    RAISES (ValueError): If a requirement string can't be parsed (E952).
    """
    # Imported locally so this function is the only place that depends on it.
    from packaging.utils import canonicalize_name

    merged = {}
    for name, package_requirements in pipe_requirements.items():
        for req_str in package_requirements:
            try:
                req = Requirement(req_str)
            except InvalidRequirement as err:
                # Keep the parser's own error as the cause for easier debugging
                raise ValueError(Errors.E952.format(name=name, req=req_str)) from err
            # Key on the normalized name so different spellings of the same
            # distribution end up in one entry instead of two conflicting ones
            key = canonicalize_name(req.name)
            existing = merged.get(key)
            if existing is None:
                merged[key] = req
                continue
            # Combine version specifiers like <3.0 and >=2.0, and merge
            # extras (e.g. spacy[lookups]) in a single Requirement object.
            # NOTE: markers and URLs of later duplicates are not merged here.
            existing.specifier = existing.specifier & req.specifier
            existing.extras.update(req.extras)
    return [str(requirement) for requirement in merged.values()]