mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 08:42:28 +03:00
Allow components to define package requirements
This commit is contained in:
parent
06a97a8766
commit
28bf43bffb
|
@ -483,6 +483,7 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E952 = ("Invalid requirement specified by component '{name}': {req}")
|
||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
||||
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
|
||||
|
|
|
@ -203,6 +203,7 @@ class Language:
|
|||
self._meta.setdefault("url", "")
|
||||
self._meta.setdefault("license", "")
|
||||
self._meta.setdefault("spacy_git_version", GIT_VERSION)
|
||||
self._meta.setdefault("requirements", [])
|
||||
self._meta["vectors"] = {
|
||||
"width": self.vocab.vectors_length,
|
||||
"vectors": len(self.vocab.vectors),
|
||||
|
@ -210,6 +211,8 @@ class Language:
|
|||
"name": self.vocab.vectors.name,
|
||||
}
|
||||
self._meta["labels"] = self.pipe_labels
|
||||
reqs = {p: self.get_pipe_meta(p).package_requirements for p in self.pipe_names}
|
||||
self._meta["requirements"].extend(util.merge_pipe_requirements(reqs))
|
||||
return self._meta
|
||||
|
||||
@meta.setter
|
||||
|
@ -356,6 +359,7 @@ class Language:
|
|||
retokenizes: bool = False,
|
||||
scores: Iterable[str] = tuple(),
|
||||
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||
package_requirements: Iterable[str] = tuple(),
|
||||
func: Optional[Callable] = None,
|
||||
) -> Callable:
|
||||
"""Register a new pipeline component factory. Can be used as a decorator
|
||||
|
@ -410,6 +414,7 @@ class Language:
|
|||
scores=scores,
|
||||
default_score_weights=default_score_weights,
|
||||
retokenizes=retokenizes,
|
||||
package_requirements=package_requirements,
|
||||
)
|
||||
cls.set_factory_meta(name, factory_meta)
|
||||
# We're overwriting the class attr with a frozen dict to handle
|
||||
|
@ -435,6 +440,7 @@ class Language:
|
|||
retokenizes: bool = False,
|
||||
scores: Iterable[str] = tuple(),
|
||||
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||
package_requirements: Iterable[str] = tuple(),
|
||||
func: Optional[Callable[[Doc], Doc]] = None,
|
||||
) -> Callable:
|
||||
"""Register a new pipeline component. Can be used for stateless function
|
||||
|
@ -476,6 +482,7 @@ class Language:
|
|||
retokenizes=retokenizes,
|
||||
scores=scores,
|
||||
default_score_weights=default_score_weights,
|
||||
package_requirements=package_requirements,
|
||||
func=factory_func,
|
||||
)
|
||||
return component_func
|
||||
|
@ -1513,6 +1520,7 @@ class FactoryMeta:
|
|||
retokenizes: bool = False
|
||||
scores: Iterable[str] = tuple()
|
||||
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
|
||||
package_requirements: Iterable[str] = tuple()
|
||||
|
||||
|
||||
def _get_config_overrides(
|
||||
|
|
|
@ -157,3 +157,15 @@ def test_dot_to_dict(dot_notation, expected):
|
|||
result = util.dot_to_dict(dot_notation)
|
||||
assert result == expected
|
||||
assert util.dict_to_dot(result) == dot_notation
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "reqs,expected",
    [
        ([["a>=3.0,<2.0", "b[x]>2.1"], ["b<3.0"]], ["a<2.0,>=3.0", "b[x]<3.0,>2.1"],),
        ([["a[foo]<1"], ["a[bar]>=5"]], ["a[bar,foo]<1,>=5"]),
    ],
)
def test_merge_pipe_requirements(reqs, expected):
    """Check that requirements declared by several components are merged into
    one requirement string per package, with specifiers and extras combined.
    """
    # Each inner list stands for one component; use its index as a stand-in
    # for the component name.
    requirements_by_pipe = {}
    for index, component_reqs in enumerate(reqs):
        requirements_by_pipe[str(index)] = component_reqs
    merged = util.merge_pipe_requirements(requirements_by_pipe)
    assert merged == expected
|
||||
|
|
|
@ -18,6 +18,7 @@ import sys
|
|||
import warnings
|
||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from packaging.requirements import Requirement, InvalidRequirement
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
import tempfile
|
||||
|
@ -388,6 +389,34 @@ def get_base_version(version: str) -> str:
|
|||
return Version(version).base_version
|
||||
|
||||
|
||||
def merge_pipe_requirements(pipe_requirements: Dict[str, Iterable[str]]) -> List[str]:
    """Merge package requirements specified by pipeline components.

    Requirements that refer to the same package are combined into a single
    requirement by intersecting their version specifiers and uniting their
    extras. Package names are compared in their normalized form, so spellings
    that differ only in case or in the use of "-", "_" and "." are treated as
    the same package; the spelling of the first occurrence is kept in the
    output. No check is made that the combined specifiers can actually be
    satisfied, so the result may contain requirements that can never be met,
    like >=3.0,<3.0. Only specifiers and extras are merged: any URL or
    environment marker is taken from the first requirement seen for a package.
    A ValueError (E952) is raised if a requirement string can't be parsed.

    pipe_requirements (Dict[str, Iterable[str]]): Requirement strings specified
        by the components, keyed by component name.
    RETURNS (List[str]): The combined requirements for the pipeline.
    """
    # Imported locally because only this function needs the helper.
    from packaging.utils import canonicalize_name

    merged = {}
    for name, package_requirements in pipe_requirements.items():
        for req_str in package_requirements:
            try:
                req = Requirement(req_str)
            except InvalidRequirement as err:
                # Chain the original error so the parser's explanation of
                # what is wrong with the string isn't lost.
                raise ValueError(Errors.E952.format(name=name, req=req_str)) from err
            key = canonicalize_name(req.name)
            existing = merged.get(key)
            if existing is None:
                merged[key] = req
            else:
                # Combine version specifiers like <3.0 and >=2.0, and merge
                # extras (e.g. spacy[lookups]) in a single Requirement object
                existing.specifier = existing.specifier & req.specifier
                existing.extras.update(req.extras)
    return [str(requirement) for requirement in merged.values()]
|
||||
|
||||
|
||||
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
||||
"""Get model meta.json from a directory path and validate its contents.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user