Allow components to define package requirements

This commit is contained in:
Ines Montani 2020-07-28 16:02:46 +02:00
parent 06a97a8766
commit 28bf43bffb
4 changed files with 50 additions and 0 deletions

View File

@ -483,6 +483,7 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E952 = ("Invalid requirement specified by component '{name}': {req}")
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.") E954 = ("The Tok2Vec listener did not receive a valid input.")
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")

View File

@ -203,6 +203,7 @@ class Language:
self._meta.setdefault("url", "") self._meta.setdefault("url", "")
self._meta.setdefault("license", "") self._meta.setdefault("license", "")
self._meta.setdefault("spacy_git_version", GIT_VERSION) self._meta.setdefault("spacy_git_version", GIT_VERSION)
self._meta.setdefault("requirements", [])
self._meta["vectors"] = { self._meta["vectors"] = {
"width": self.vocab.vectors_length, "width": self.vocab.vectors_length,
"vectors": len(self.vocab.vectors), "vectors": len(self.vocab.vectors),
@ -210,6 +211,8 @@ class Language:
"name": self.vocab.vectors.name, "name": self.vocab.vectors.name,
} }
self._meta["labels"] = self.pipe_labels self._meta["labels"] = self.pipe_labels
reqs = {p: self.get_pipe_meta(p).package_requirements for p in self.pipe_names}
self._meta["requirements"].extend(util.merge_pipe_requirements(reqs))
return self._meta return self._meta
@meta.setter @meta.setter
@ -356,6 +359,7 @@ class Language:
retokenizes: bool = False, retokenizes: bool = False,
scores: Iterable[str] = tuple(), scores: Iterable[str] = tuple(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(), default_score_weights: Dict[str, float] = SimpleFrozenDict(),
package_requirements: Iterable[str] = tuple(),
func: Optional[Callable] = None, func: Optional[Callable] = None,
) -> Callable: ) -> Callable:
"""Register a new pipeline component factory. Can be used as a decorator """Register a new pipeline component factory. Can be used as a decorator
@ -410,6 +414,7 @@ class Language:
scores=scores, scores=scores,
default_score_weights=default_score_weights, default_score_weights=default_score_weights,
retokenizes=retokenizes, retokenizes=retokenizes,
package_requirements=package_requirements,
) )
cls.set_factory_meta(name, factory_meta) cls.set_factory_meta(name, factory_meta)
# We're overwriting the class attr with a frozen dict to handle # We're overwriting the class attr with a frozen dict to handle
@ -435,6 +440,7 @@ class Language:
retokenizes: bool = False, retokenizes: bool = False,
scores: Iterable[str] = tuple(), scores: Iterable[str] = tuple(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(), default_score_weights: Dict[str, float] = SimpleFrozenDict(),
package_requirements: Iterable[str] = tuple(),
func: Optional[Callable[[Doc], Doc]] = None, func: Optional[Callable[[Doc], Doc]] = None,
) -> Callable: ) -> Callable:
"""Register a new pipeline component. Can be used for stateless function """Register a new pipeline component. Can be used for stateless function
@ -476,6 +482,7 @@ class Language:
retokenizes=retokenizes, retokenizes=retokenizes,
scores=scores, scores=scores,
default_score_weights=default_score_weights, default_score_weights=default_score_weights,
package_requirements=package_requirements,
func=factory_func, func=factory_func,
) )
return component_func return component_func
@ -1513,6 +1520,7 @@ class FactoryMeta:
retokenizes: bool = False retokenizes: bool = False
scores: Iterable[str] = tuple() scores: Iterable[str] = tuple()
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704 default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
package_requirements: Iterable[str] = tuple()
def _get_config_overrides( def _get_config_overrides(

View File

@ -157,3 +157,15 @@ def test_dot_to_dict(dot_notation, expected):
result = util.dot_to_dict(dot_notation) result = util.dot_to_dict(dot_notation)
assert result == expected assert result == expected
assert util.dict_to_dot(result) == dot_notation assert util.dict_to_dot(result) == dot_notation
@pytest.mark.parametrize(
"reqs,expected",
[
([["a>=3.0,<2.0", "b[x]>2.1"], ["b<3.0"]], ["a<2.0,>=3.0", "b[x]<3.0,>2.1"],),
([["a[foo]<1"], ["a[bar]>=5"]], ["a[bar,foo]<1,>=5"]),
],
)
def test_merge_pipe_requirements(reqs, expected):
pipe_reqs = {str(i): req for i, req in enumerate(reqs)}
assert util.merge_pipe_requirements(pipe_reqs) == expected

View File

@ -18,6 +18,7 @@ import sys
import warnings import warnings
from packaging.specifiers import SpecifierSet, InvalidSpecifier from packaging.specifiers import SpecifierSet, InvalidSpecifier
from packaging.version import Version, InvalidVersion from packaging.version import Version, InvalidVersion
from packaging.requirements import Requirement, InvalidRequirement
import subprocess import subprocess
from contextlib import contextmanager from contextlib import contextmanager
import tempfile import tempfile
@ -388,6 +389,34 @@ def get_base_version(version: str) -> str:
return Version(version).base_version return Version(version).base_version
def merge_pipe_requirements(pipe_requirements: Dict[str, Iterable[str]]) -> List[str]:
"""Merge package requirements specified by pipeline components. Since
there's no convenient way (?) to check that version specifier sets are
valid and can be met, we're currently only combining them using the built-in
API for handling specifiers. In theory, this could create requirements that
can never be met, like >=3.0,<3.0.
pipe_requirements (Dict[str, Iterable[str]]): Requirement strings specified
by the components, keyed by component name.
RETURNS (List[str]): The combined requirements for the pipeline.
"""
result = {}
for name, package_requirements in pipe_requirements.items():
for req_str in package_requirements:
try:
req = Requirement(req_str)
except InvalidRequirement:
raise ValueError(Errors.E952.format(name=name, req=req_str))
if req.name not in result:
result[req.name] = req
else:
# Combine version specifiers like <3.0 and >=2.0, and merge
# extras (e.g. spacy[lookups]) in a single Requirement object
result[req.name].specifier = result[req.name].specifier & req.specifier
result[req.name].extras.update(req.extras)
return [str(requirement) for requirement in result.values()]
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]: def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
"""Get model meta.json from a directory path and validate its contents. """Get model meta.json from a directory path and validate its contents.