Centralise registry calls

This commit is contained in:
Matthew Honnibal 2025-05-19 12:33:58 +02:00
parent 9d7b22c52e
commit 43f87b991b
5 changed files with 12 additions and 5 deletions

View File

@ -19,6 +19,7 @@ from .glossary import explain # noqa: F401
from .language import Language
from .util import logger, registry # noqa: F401
from .vocab import Vocab
from .registrations import populate_registry, REGISTRY_POPULATED
# A sys.maxunicode of 65535 (0xFFFF) means the interpreter is a narrow
# Unicode build; abort at import time with SystemError in that case.
# NOTE(review): Errors.E130 is defined elsewhere -- presumably the message
# describing this unsupported build; confirm against the Errors class.
if sys.maxunicode == 65535:
raise SystemError(Errors.E130)

View File

@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor
from ..staticvectors import StaticVectors
@registry.architectures("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width: int, upstream: str = "*"):
    """Construct a Tok2VecListener, registered as "spacy.Tok2VecListener.v1".

    width (int): Forwarded to the listener as its `width`.
    upstream (str): Forwarded to the listener as its `upstream_name`.
        Defaults to "*".
    RETURNS: The newly created Tok2VecListener instance.
    """
    return Tok2VecListener(upstream_name=upstream, width=width)
@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model):
return nO
@registry.architectures("spacy.HashEmbedCNN.v2")
def build_hash_embed_cnn_tok2vec(
*,
width: int,

View File

@ -183,7 +183,6 @@ def ner_score(examples, **kwargs):
return get_ner_prf(examples, **kwargs)
@registry.scorers("spacy.ner_scorer.v1")
def make_ner_scorer():
    """Scorer factory registered under "spacy.ner_scorer.v1".

    RETURNS (Callable): The `ner_score` function object itself; it is
        handed back uncalled.
    """
    scorer = ner_score
    return scorer

View File

@ -64,7 +64,6 @@ def tagger_score(examples, **kwargs):
return Scorer.score_token_attr(examples, "tag", **kwargs)
@registry.scorers("spacy.tagger_scorer.v1")
def make_tagger_scorer():
    """Scorer factory registered under "spacy.tagger_scorer.v1".

    RETURNS (Callable): The `tagger_score` function object itself; it is
        handed back uncalled.
    """
    scorer = tagger_score
    return scorer

View File

@ -132,9 +132,17 @@ class registry(thinc.registry):
models = catalogue.create("spacy", "models", entry_points=True)
cli = catalogue.create("spacy", "cli", entry_points=True)
@classmethod
def ensure_populated(cls) -> None:
    """Run `populate_registry()` unless the registry is flagged as populated.

    The import is deliberately function-local: it is executed on every
    call, so `REGISTRY_POPULATED` is re-read from `.registrations` each
    time rather than being frozen at module import.
    NOTE(review): presumably `populate_registry()` sets that flag so later
    calls become no-ops, and the local import also avoids a circular
    import -- confirm in `registrations`.
    """
    from .registrations import populate_registry, REGISTRY_POPULATED

    if REGISTRY_POPULATED:
        # Nothing to do: registrations have already been performed.
        return
    populate_registry()
@classmethod
def get_registry_names(cls) -> List[str]:
"""List all available registries."""
cls.ensure_populated()
names = []
for name, value in inspect.getmembers(cls):
if not name.startswith("_") and isinstance(value, Registry):
@ -144,6 +152,7 @@ class registry(thinc.registry):
@classmethod
def get(cls, registry_name: str, func_name: str) -> Callable:
"""Get a registered function from the registry."""
cls.ensure_populated()
# We're overwriting this classmethod so we're able to provide more
# specific error messages and implement a fallback to spacy-legacy.
if not hasattr(cls, registry_name):
@ -179,6 +188,7 @@ class registry(thinc.registry):
func_name (str): Name of the registered function.
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
"""
cls.ensure_populated()
# We're overwriting this classmethod so we're able to provide more
# specific error messages and implement a fallback to spacy-legacy.
if not hasattr(cls, registry_name):
@ -205,6 +215,7 @@ class registry(thinc.registry):
@classmethod
def has(cls, registry_name: str, func_name: str) -> bool:
"""Check whether a function is available in a registry."""
cls.ensure_populated()
if not hasattr(cls, registry_name):
return False
reg = getattr(cls, registry_name)
@ -1323,7 +1334,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
return filter_spans(itertools.chain(*spans))
@registry.misc("spacy.first_longest_spans_filter.v1")
def make_first_longest_spans_filter():
    """Factory registered under "spacy.first_longest_spans_filter.v1".

    RETURNS (Callable): The `filter_chain_spans` function object itself;
        it is handed back uncalled.
    """
    span_filter = filter_chain_spans
    return span_filter