mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Handle spacy-legacy in package CLI for dependencies (#9163)
* Handle spacy-legacy in package CLI for dependencies
* Implement legacy backoff in spacy registry.find
* Remove unused import
* Update and format test
This commit is contained in:
parent
632d8d4c35
commit
aba6ce3a43
|
@ -200,7 +200,7 @@ def get_third_party_dependencies(
|
||||||
exclude (list): List of packages to exclude (e.g. that already exist in meta).
|
exclude (list): List of packages to exclude (e.g. that already exist in meta).
|
||||||
RETURNS (list): The versioned requirements.
|
RETURNS (list): The versioned requirements.
|
||||||
"""
|
"""
|
||||||
own_packages = ("spacy", "spacy-nightly", "thinc", "srsly")
|
own_packages = ("spacy", "spacy-legacy", "spacy-nightly", "thinc", "srsly")
|
||||||
distributions = util.packages_distributions()
|
distributions = util.packages_distributions()
|
||||||
funcs = defaultdict(set)
|
funcs = defaultdict(set)
|
||||||
for path, value in util.walk_dict(config):
|
for path, value in util.walk_dict(config):
|
||||||
|
@ -211,9 +211,8 @@ def get_third_party_dependencies(
|
||||||
funcs["factories"].add(component["factory"])
|
funcs["factories"].add(component["factory"])
|
||||||
modules = set()
|
modules = set()
|
||||||
for reg_name, func_names in funcs.items():
|
for reg_name, func_names in funcs.items():
|
||||||
sub_registry = getattr(util.registry, reg_name)
|
|
||||||
for func_name in func_names:
|
for func_name in func_names:
|
||||||
func_info = sub_registry.find(func_name)
|
func_info = util.registry.find(reg_name, func_name)
|
||||||
module_name = func_info.get("module")
|
module_name = func_info.get("module")
|
||||||
if module_name: # the code is part of a module, not a --code file
|
if module_name: # the code is part of a module, not a --code file
|
||||||
modules.add(func_info["module"].split(".")[0])
|
modules.add(func_info["module"].split(".")[0])
|
||||||
|
|
|
@ -536,7 +536,7 @@ def test_init_labels(component_name):
|
||||||
assert len(nlp2.get_pipe(component_name).labels) == 4
|
assert len(nlp2.get_pipe(component_name).labels) == 4
|
||||||
|
|
||||||
|
|
||||||
def test_get_third_party_dependencies_runs():
|
def test_get_third_party_dependencies():
|
||||||
# We can't easily test the detection of third-party packages here, but we
|
# We can't easily test the detection of third-party packages here, but we
|
||||||
# can at least make sure that the function and its importlib magic runs.
|
# can at least make sure that the function and its importlib magic runs.
|
||||||
nlp = Dutch()
|
nlp = Dutch()
|
||||||
|
@ -544,6 +544,22 @@ def test_get_third_party_dependencies_runs():
|
||||||
nlp.add_pipe("tagger")
|
nlp.add_pipe("tagger")
|
||||||
assert get_third_party_dependencies(nlp.config) == []
|
assert get_third_party_dependencies(nlp.config) == []
|
||||||
|
|
||||||
|
# Test with legacy function
nlp = Dutch()
nlp.add_pipe(
    "textcat",
    config={
        "model": {
            # Do not update from legacy architecture spacy.TextCatBOW.v1
            "@architectures": "spacy.TextCatBOW.v1",
            "exclusive_classes": True,
            "ngram_size": 1,
            "no_output_layer": False,
        }
    },
)
# A legacy architecture should not be reported as a third-party
# dependency, so the expected result is an empty list.
assert get_third_party_dependencies(nlp.config) == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"parent,child,expected",
|
"parent,child,expected",
|
||||||
|
|
|
@ -140,6 +140,32 @@ class registry(thinc.registry):
|
||||||
) from None
|
) from None
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
@classmethod
def find(cls, registry_name: str, func_name: str) -> dict:
    """Get info about a registered function from the registry.

    This classmethod is overwritten so we're able to provide more specific
    error messages and implement a fallback to spacy-legacy.

    registry_name (str): Name of the sub-registry, looked up as an attribute
        of this class.
    func_name (str): Registered name of the function to look up.
    RETURNS (dict): The info returned by the sub-registry's own find() for
        the function, or for its "spacy-legacy." counterpart if the name
        starts with "spacy." and only the legacy entry exists.
    RAISES (RegistryError): E892 if the registry does not exist, E893 if the
        function is registered under neither name.
    """
    if not hasattr(cls, registry_name):
        names = ", ".join(cls.get_registry_names()) or "none"
        raise RegistryError(Errors.E892.format(name=registry_name, available=names))
    reg = getattr(cls, registry_name)
    try:
        func_info = reg.find(func_name)
    except RegistryError:
        prefix = "spacy."
        if func_name.startswith(prefix):
            # Swap only the leading namespace; a later "spacy." inside the
            # name must stay untouched.
            legacy_name = "spacy-legacy." + func_name[len(prefix) :]
            try:
                return reg.find(legacy_name)
            except RegistryError:
                # Not available in spacy-legacy either: fall through and
                # report the name the caller originally asked for.
                pass
        available = ", ".join(sorted(reg.get_all().keys())) or "none"
        raise RegistryError(
            Errors.E893.format(
                name=func_name, reg_name=registry_name, available=available
            )
        ) from None
    return func_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has(cls, registry_name: str, func_name: str) -> bool:
|
def has(cls, registry_name: str, func_name: str) -> bool:
|
||||||
"""Check whether a function is available in a registry."""
|
"""Check whether a function is available in a registry."""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user