mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 17:22:25 +03:00
Centralise registry calls
This commit is contained in:
parent
9d7b22c52e
commit
43f87b991b
|
@ -19,6 +19,7 @@ from .glossary import explain # noqa: F401
|
|||
from .language import Language
|
||||
from .util import logger, registry # noqa: F401
|
||||
from .vocab import Vocab
|
||||
from .registrations import populate_registry, REGISTRY_POPULATED
|
||||
|
||||
# 0xFFFF (65535) is the value sys.maxunicode reports on a narrow (UCS-2)
# interpreter build; refuse to continue on such a build.
if sys.maxunicode == 0xFFFF:
    raise SystemError(Errors.E130)
|
||||
|
|
|
@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor
|
|||
from ..staticvectors import StaticVectors
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width: int, upstream: str = "*"):
    """Create a Tok2VecListener layer.

    width (int): Forwarded to the listener as its `width`.
    upstream (str): Forwarded to the listener as `upstream_name`.
        Defaults to "*".
    RETURNS: The newly constructed Tok2VecListener.
    """
    return Tok2VecListener(upstream_name=upstream, width=width)
|
||||
|
@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model):
|
|||
return nO
|
||||
|
||||
|
||||
@registry.architectures("spacy.HashEmbedCNN.v2")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
|
|
|
@ -183,7 +183,6 @@ def ner_score(examples, **kwargs):
|
|||
return get_ner_prf(examples, **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.ner_scorer.v1")
def make_ner_scorer():
    """Provide the NER scoring callable.

    RETURNS: The `ner_score` function itself (it is not called here).
    """
    scorer = ner_score
    return scorer
|
||||
|
||||
|
|
|
@ -64,7 +64,6 @@ def tagger_score(examples, **kwargs):
|
|||
return Scorer.score_token_attr(examples, "tag", **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.tagger_scorer.v1")
def make_tagger_scorer():
    """Provide the tagger scoring callable.

    RETURNS: The `tagger_score` function itself (it is not called here).
    """
    scorer = tagger_score
    return scorer
|
||||
|
||||
|
|
|
@ -132,9 +132,17 @@ class registry(thinc.registry):
|
|||
models = catalogue.create("spacy", "models", entry_points=True)
|
||||
cli = catalogue.create("spacy", "cli", entry_points=True)
|
||||
|
||||
@classmethod
def ensure_populated(cls) -> None:
    """Populate the registry if that has not happened yet.

    Calls `populate_registry()` only while the `REGISTRY_POPULATED` flag in
    the `registrations` module is falsy; otherwise does nothing.
    """
    # The import lives inside the method, so the flag is looked up again on
    # every call instead of being captured once at module import time.
    # NOTE(review): presumably also avoids a circular import — confirm.
    from .registrations import populate_registry, REGISTRY_POPULATED

    if REGISTRY_POPULATED:
        return
    populate_registry()
|
||||
|
||||
@classmethod
|
||||
def get_registry_names(cls) -> List[str]:
|
||||
"""List all available registries."""
|
||||
cls.ensure_populated()
|
||||
names = []
|
||||
for name, value in inspect.getmembers(cls):
|
||||
if not name.startswith("_") and isinstance(value, Registry):
|
||||
|
@ -144,6 +152,7 @@ class registry(thinc.registry):
|
|||
@classmethod
|
||||
def get(cls, registry_name: str, func_name: str) -> Callable:
|
||||
"""Get a registered function from the registry."""
|
||||
cls.ensure_populated()
|
||||
# We're overwriting this classmethod so we're able to provide more
|
||||
# specific error messages and implement a fallback to spacy-legacy.
|
||||
if not hasattr(cls, registry_name):
|
||||
|
@ -179,6 +188,7 @@ class registry(thinc.registry):
|
|||
func_name (str): Name of the registered function.
|
||||
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
|
||||
"""
|
||||
cls.ensure_populated()
|
||||
# We're overwriting this classmethod so we're able to provide more
|
||||
# specific error messages and implement a fallback to spacy-legacy.
|
||||
if not hasattr(cls, registry_name):
|
||||
|
@ -205,6 +215,7 @@ class registry(thinc.registry):
|
|||
@classmethod
|
||||
def has(cls, registry_name: str, func_name: str) -> bool:
|
||||
"""Check whether a function is available in a registry."""
|
||||
cls.ensure_populated()
|
||||
if not hasattr(cls, registry_name):
|
||||
return False
|
||||
reg = getattr(cls, registry_name)
|
||||
|
@ -1323,7 +1334,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
|
|||
return filter_spans(itertools.chain(*spans))
|
||||
|
||||
|
||||
@registry.misc("spacy.first_longest_spans_filter.v1")
def make_first_longest_spans_filter():
    """Provide the span filter callable.

    RETURNS: The `filter_chain_spans` function itself (it is not called
        here).
    """
    span_filter = filter_chain_spans
    return span_filter
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user