From 43f87b991bfd97ab5e902a24de213f92a53cd333 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 19 May 2025 12:33:58 +0200
Subject: [PATCH] Centralise registry calls

---
Reviewer notes (text in this area is not part of the commit message):

- Layout: line breaks, blank context lines and indentation below were
  re-derived from the hunk headers' line counts and 4-space Python
  indentation. Confirm against the original patch before applying.
- NOTE(review): the From: header carries no e-mail address here;
  presumably lost in extraction -- restore it before using `git am`.
- Removes the `@registry.architectures(...)`, `@registry.scorers(...)`
  and `@registry.misc(...)` decorators from five functions (two in
  tok2vec.py, one each in ner.pyx, tagger.pyx and util.py). The function
  bodies shown in the hunks are unchanged.
- Adds the classmethod `registry.ensure_populated()` in spacy/util.py.
  It imports `populate_registry` and `REGISTRY_POPULATED` from
  `.registrations` at call time and calls `populate_registry()` when the
  flag is falsy.
- `cls.ensure_populated()` is inserted right after the docstring of
  `get_registry_names`, `get` and `has`, and of one further method of
  `registry` whose `def` line lies outside its hunk (-179/+188).
- NOTE(review): spacy/registrations.py is not touched by this patch;
  presumably it is introduced by a separate commit -- verify it exists.
- NOTE(review): presumably `populate_registry()` registers the five names
  whose decorators are removed here and updates `REGISTRY_POPULATED`;
  neither is visible in this patch -- confirm.
- NOTE(review): the spacy/__init__.py import binds `REGISTRY_POPULATED`
  by value at import time; if `.registrations` later rebinds the flag,
  `spacy.REGISTRY_POPULATED` will not follow -- confirm this is intended.

 spacy/__init__.py          |  1 +
 spacy/ml/models/tok2vec.py |  2 --
 spacy/pipeline/ner.pyx     |  1 -
 spacy/pipeline/tagger.pyx  |  1 -
 spacy/util.py              | 12 +++++++++++-
 5 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/spacy/__init__.py b/spacy/__init__.py
index 1a18ad0d5..efc475dc9 100644
--- a/spacy/__init__.py
+++ b/spacy/__init__.py
@@ -19,6 +19,7 @@ from .glossary import explain  # noqa: F401
 from .language import Language
 from .util import logger, registry  # noqa: F401
 from .vocab import Vocab
+from .registrations import populate_registry, REGISTRY_POPULATED
 
 if sys.maxunicode == 65535:
     raise SystemError(Errors.E130)
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index 0edc89991..e1c6db690 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor
 from ..staticvectors import StaticVectors
 
 
-@registry.architectures("spacy.Tok2VecListener.v1")
 def tok2vec_listener_v1(width: int, upstream: str = "*"):
     tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
     return tok2vec
@@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model):
     return nO
 
 
-@registry.architectures("spacy.HashEmbedCNN.v2")
 def build_hash_embed_cnn_tok2vec(
     *,
     width: int,
diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx
index bb009dc7a..b8663937b 100644
--- a/spacy/pipeline/ner.pyx
+++ b/spacy/pipeline/ner.pyx
@@ -183,7 +183,6 @@ def ner_score(examples, **kwargs):
     return get_ner_prf(examples, **kwargs)
 
 
-@registry.scorers("spacy.ner_scorer.v1")
 def make_ner_scorer():
     return ner_score
 
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index 34e85d49c..28d4c6e7f 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -64,7 +64,6 @@ def tagger_score(examples, **kwargs):
     return Scorer.score_token_attr(examples, "tag", **kwargs)
 
 
-@registry.scorers("spacy.tagger_scorer.v1")
 def make_tagger_scorer():
     return tagger_score
 
diff --git a/spacy/util.py b/spacy/util.py
index c127be03c..96b52e21d 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -132,9 +132,17 @@ class registry(thinc.registry):
     models = catalogue.create("spacy", "models", entry_points=True)
     cli = catalogue.create("spacy", "cli", entry_points=True)
 
+    @classmethod
+    def ensure_populated(cls) -> None:
+        """Ensure the registry is populated with all necessary components."""
+        from .registrations import populate_registry, REGISTRY_POPULATED
+        if not REGISTRY_POPULATED:
+            populate_registry()
+
     @classmethod
     def get_registry_names(cls) -> List[str]:
         """List all available registries."""
+        cls.ensure_populated()
         names = []
         for name, value in inspect.getmembers(cls):
             if not name.startswith("_") and isinstance(value, Registry):
@@ -144,6 +152,7 @@ class registry(thinc.registry):
     @classmethod
     def get(cls, registry_name: str, func_name: str) -> Callable:
         """Get a registered function from the registry."""
+        cls.ensure_populated()
         # We're overwriting this classmethod so we're able to provide more
         # specific error messages and implement a fallback to spacy-legacy.
         if not hasattr(cls, registry_name):
@@ -179,6 +188,7 @@ class registry(thinc.registry):
         func_name (str): Name of the registered function.
         RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
         """
+        cls.ensure_populated()
         # We're overwriting this classmethod so we're able to provide more
         # specific error messages and implement a fallback to spacy-legacy.
         if not hasattr(cls, registry_name):
@@ -205,6 +215,7 @@ class registry(thinc.registry):
     @classmethod
     def has(cls, registry_name: str, func_name: str) -> bool:
         """Check whether a function is available in a registry."""
+        cls.ensure_populated()
         if not hasattr(cls, registry_name):
             return False
         reg = getattr(cls, registry_name)
@@ -1323,7 +1334,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
     return filter_spans(itertools.chain(*spans))
 
 
-@registry.misc("spacy.first_longest_spans_filter.v1")
 def make_first_longest_spans_filter():
     return filter_chain_spans
 