mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 16:52:21 +03:00
Allow components to define package requirements
This commit is contained in:
parent
06a97a8766
commit
28bf43bffb
|
@ -483,6 +483,7 @@ class Errors:
|
||||||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E952 = ("Invalid requirement specified by component '{name}': {req}")
|
||||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
||||||
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
|
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
|
||||||
|
|
|
@ -203,6 +203,7 @@ class Language:
|
||||||
self._meta.setdefault("url", "")
|
self._meta.setdefault("url", "")
|
||||||
self._meta.setdefault("license", "")
|
self._meta.setdefault("license", "")
|
||||||
self._meta.setdefault("spacy_git_version", GIT_VERSION)
|
self._meta.setdefault("spacy_git_version", GIT_VERSION)
|
||||||
|
self._meta.setdefault("requirements", [])
|
||||||
self._meta["vectors"] = {
|
self._meta["vectors"] = {
|
||||||
"width": self.vocab.vectors_length,
|
"width": self.vocab.vectors_length,
|
||||||
"vectors": len(self.vocab.vectors),
|
"vectors": len(self.vocab.vectors),
|
||||||
|
@ -210,6 +211,8 @@ class Language:
|
||||||
"name": self.vocab.vectors.name,
|
"name": self.vocab.vectors.name,
|
||||||
}
|
}
|
||||||
self._meta["labels"] = self.pipe_labels
|
self._meta["labels"] = self.pipe_labels
|
||||||
|
reqs = {p: self.get_pipe_meta(p).package_requirements for p in self.pipe_names}
|
||||||
|
self._meta["requirements"].extend(util.merge_pipe_requirements(reqs))
|
||||||
return self._meta
|
return self._meta
|
||||||
|
|
||||||
@meta.setter
|
@meta.setter
|
||||||
|
@ -356,6 +359,7 @@ class Language:
|
||||||
retokenizes: bool = False,
|
retokenizes: bool = False,
|
||||||
scores: Iterable[str] = tuple(),
|
scores: Iterable[str] = tuple(),
|
||||||
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||||
|
package_requirements: Iterable[str] = tuple(),
|
||||||
func: Optional[Callable] = None,
|
func: Optional[Callable] = None,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Register a new pipeline component factory. Can be used as a decorator
|
"""Register a new pipeline component factory. Can be used as a decorator
|
||||||
|
@ -410,6 +414,7 @@ class Language:
|
||||||
scores=scores,
|
scores=scores,
|
||||||
default_score_weights=default_score_weights,
|
default_score_weights=default_score_weights,
|
||||||
retokenizes=retokenizes,
|
retokenizes=retokenizes,
|
||||||
|
package_requirements=package_requirements,
|
||||||
)
|
)
|
||||||
cls.set_factory_meta(name, factory_meta)
|
cls.set_factory_meta(name, factory_meta)
|
||||||
# We're overwriting the class attr with a frozen dict to handle
|
# We're overwriting the class attr with a frozen dict to handle
|
||||||
|
@ -435,6 +440,7 @@ class Language:
|
||||||
retokenizes: bool = False,
|
retokenizes: bool = False,
|
||||||
scores: Iterable[str] = tuple(),
|
scores: Iterable[str] = tuple(),
|
||||||
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||||
|
package_requirements: Iterable[str] = tuple(),
|
||||||
func: Optional[Callable[[Doc], Doc]] = None,
|
func: Optional[Callable[[Doc], Doc]] = None,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Register a new pipeline component. Can be used for stateless function
|
"""Register a new pipeline component. Can be used for stateless function
|
||||||
|
@ -476,6 +482,7 @@ class Language:
|
||||||
retokenizes=retokenizes,
|
retokenizes=retokenizes,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
default_score_weights=default_score_weights,
|
default_score_weights=default_score_weights,
|
||||||
|
package_requirements=package_requirements,
|
||||||
func=factory_func,
|
func=factory_func,
|
||||||
)
|
)
|
||||||
return component_func
|
return component_func
|
||||||
|
@ -1513,6 +1520,7 @@ class FactoryMeta:
|
||||||
retokenizes: bool = False
|
retokenizes: bool = False
|
||||||
scores: Iterable[str] = tuple()
|
scores: Iterable[str] = tuple()
|
||||||
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
|
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
|
||||||
|
package_requirements: Iterable[str] = tuple()
|
||||||
|
|
||||||
|
|
||||||
def _get_config_overrides(
|
def _get_config_overrides(
|
||||||
|
|
|
@ -157,3 +157,15 @@ def test_dot_to_dict(dot_notation, expected):
|
||||||
result = util.dot_to_dict(dot_notation)
|
result = util.dot_to_dict(dot_notation)
|
||||||
assert result == expected
|
assert result == expected
|
||||||
assert util.dict_to_dot(result) == dot_notation
|
assert util.dict_to_dot(result) == dot_notation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "reqs,expected",
    [
        # Same package listed by two components: specifiers are combined and
        # the extras of the first mention are kept.
        (
            [["a>=3.0,<2.0", "b[x]>2.1"], ["b<3.0"]],
            ["a<2.0,>=3.0", "b[x]<3.0,>2.1"],
        ),
        # Different extras for the same package end up in one requirement.
        ([["a[foo]<1"], ["a[bar]>=5"]], ["a[bar,foo]<1,>=5"]),
    ],
)
def test_merge_pipe_requirements(reqs, expected):
    """Check that per-component requirement lists are merged per package.

    Each inner list in `reqs` stands in for the requirements of one pipeline
    component; the components are simply named "0", "1", ... by position.
    """
    component_names = map(str, range(len(reqs)))
    pipe_reqs = dict(zip(component_names, reqs))
    merged = util.merge_pipe_requirements(pipe_reqs)
    assert merged == expected
|
||||||
|
|
|
@ -18,6 +18,7 @@ import sys
|
||||||
import warnings
|
import warnings
|
||||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||||
from packaging.version import Version, InvalidVersion
|
from packaging.version import Version, InvalidVersion
|
||||||
|
from packaging.requirements import Requirement, InvalidRequirement
|
||||||
import subprocess
|
import subprocess
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import tempfile
|
import tempfile
|
||||||
|
@ -388,6 +389,34 @@ def get_base_version(version: str) -> str:
|
||||||
return Version(version).base_version
|
return Version(version).base_version
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pipe_requirements(pipe_requirements: Dict[str, Iterable[str]]) -> List[str]:
    """Merge package requirements specified by pipeline components.

    Requirements that refer to the same package name are collapsed into a
    single entry: their version specifiers are combined with the specifier
    set's "&" operator and their extras (e.g. spacy[lookups]) are merged.
    The combined specifiers are not checked for satisfiability, so in theory
    this can produce requirements that can never be met, like >=3.0,<3.0.

    A ValueError (E952) naming the offending component is raised if one of
    the requirement strings can't be parsed. The original parser error is
    attached as the cause so the reason for the failure isn't lost.

    pipe_requirements (Dict[str, Iterable[str]]): Requirement strings specified
        by the components, keyed by component name.
    RETURNS (List[str]): The combined requirements for the pipeline.
    """
    merged: Dict[str, Requirement] = {}
    for name, package_requirements in pipe_requirements.items():
        for req_str in package_requirements:
            try:
                req = Requirement(req_str)
            except InvalidRequirement as err:
                # Chain explicitly: the parser's message says what exactly is
                # wrong with the string, E952 says which component supplied it.
                msg = Errors.E952.format(name=name, req=req_str)
                raise ValueError(msg) from err
            existing = merged.get(req.name)
            if existing is None:
                merged[req.name] = req
                continue
            # Combine version specifiers like <3.0 and >=2.0, and merge
            # extras (e.g. spacy[lookups]) in a single Requirement object
            existing.specifier = existing.specifier & req.specifier
            existing.extras.update(req.extras)
    return [str(requirement) for requirement in merged.values()]
|
||||||
|
|
||||||
|
|
||||||
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
||||||
"""Get model meta.json from a directory path and validate its contents.
|
"""Get model meta.json from a directory path and validate its contents.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user