Include available registry names in error

This commit is contained in:
Ines Montani 2021-01-16 14:35:03 +11:00
parent d12be459f6
commit a552db2819
2 changed files with 13 additions and 3 deletions

View File

@ -468,7 +468,7 @@ class Errors:
"If the function is provided by a third-party package, e.g. " "If the function is provided by a third-party package, e.g. "
"spacy-transformers, make sure the package is installed in your " "spacy-transformers, make sure the package is installed in your "
"environment.\n\nAvailable names: {available}") "environment.\n\nAvailable names: {available}")
E894 = ("Unknown function registry: '{name}'.") E894 = ("Unknown function registry: '{name}'.\n\nAvailable names: {available}")
E895 = ("The 'textcat' component received gold-standard annotations with " E895 = ("The 'textcat' component received gold-standard annotations with "
"multiple labels per document. In spaCy 3 you should use the " "multiple labels per document. In spaCy 3 you should use the "
"'textcat_multilabel' component for this instead. " "'textcat_multilabel' component for this instead. "

View File

@ -15,7 +15,7 @@ import numpy.random
import numpy import numpy
import srsly import srsly
import catalogue import catalogue
from catalogue import RegistryError from catalogue import RegistryError, Registry
import sys import sys
import warnings import warnings
from packaging.specifiers import SpecifierSet, InvalidSpecifier from packaging.specifiers import SpecifierSet, InvalidSpecifier
@ -106,13 +106,23 @@ class registry(thinc.registry):
models = catalogue.create("spacy", "models", entry_points=True) models = catalogue.create("spacy", "models", entry_points=True)
cli = catalogue.create("spacy", "cli", entry_points=True) cli = catalogue.create("spacy", "cli", entry_points=True)
@classmethod
def get_registry_names(cls) -> List[str]:
"""List all available registries."""
names = []
for name, value in inspect.getmembers(cls):
if not name.startswith("_") and isinstance(value, Registry):
names.append(name)
return sorted(names)
@classmethod @classmethod
def get(cls, registry_name: str, func_name: str) -> Callable: def get(cls, registry_name: str, func_name: str) -> Callable:
"""Get a registered function from the registry.""" """Get a registered function from the registry."""
# We're overwriting this classmethod so we're able to provide more # We're overwriting this classmethod so we're able to provide more
# specific error messages and implement a fallback to spacy-legacy. # specific error messages and implement a fallback to spacy-legacy.
if not hasattr(cls, registry_name): if not hasattr(cls, registry_name):
raise RegistryError(Errors.E894.format(name=registry_name)) names = ", ".join(cls.get_registry_names()) or "none"
raise RegistryError(Errors.E894.format(name=registry_name, available=names))
reg = getattr(cls, registry_name) reg = getattr(cls, registry_name)
try: try:
func = reg.get(func_name) func = reg.get(func_name)