mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 09:42:26 +03:00
Centralise registry calls
This commit is contained in:
parent
9d7b22c52e
commit
43f87b991b
|
@ -19,6 +19,7 @@ from .glossary import explain # noqa: F401
|
||||||
from .language import Language
|
from .language import Language
|
||||||
from .util import logger, registry # noqa: F401
|
from .util import logger, registry # noqa: F401
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
|
from .registrations import populate_registry, REGISTRY_POPULATED
|
||||||
|
|
||||||
if sys.maxunicode == 65535:
|
if sys.maxunicode == 65535:
|
||||||
raise SystemError(Errors.E130)
|
raise SystemError(Errors.E130)
|
||||||
|
|
|
@ -29,7 +29,6 @@ from ..featureextractor import FeatureExtractor
|
||||||
from ..staticvectors import StaticVectors
|
from ..staticvectors import StaticVectors
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.Tok2VecListener.v1")
|
|
||||||
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
||||||
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
||||||
return tok2vec
|
return tok2vec
|
||||||
|
@ -46,7 +45,6 @@ def get_tok2vec_width(model: Model):
|
||||||
return nO
|
return nO
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.HashEmbedCNN.v2")
|
|
||||||
def build_hash_embed_cnn_tok2vec(
|
def build_hash_embed_cnn_tok2vec(
|
||||||
*,
|
*,
|
||||||
width: int,
|
width: int,
|
||||||
|
|
|
@ -183,7 +183,6 @@ def ner_score(examples, **kwargs):
|
||||||
return get_ner_prf(examples, **kwargs)
|
return get_ner_prf(examples, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@registry.scorers("spacy.ner_scorer.v1")
|
|
||||||
def make_ner_scorer():
|
def make_ner_scorer():
|
||||||
return ner_score
|
return ner_score
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,6 @@ def tagger_score(examples, **kwargs):
|
||||||
return Scorer.score_token_attr(examples, "tag", **kwargs)
|
return Scorer.score_token_attr(examples, "tag", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@registry.scorers("spacy.tagger_scorer.v1")
|
|
||||||
def make_tagger_scorer():
|
def make_tagger_scorer():
|
||||||
return tagger_score
|
return tagger_score
|
||||||
|
|
||||||
|
|
|
@ -132,9 +132,17 @@ class registry(thinc.registry):
|
||||||
models = catalogue.create("spacy", "models", entry_points=True)
|
models = catalogue.create("spacy", "models", entry_points=True)
|
||||||
cli = catalogue.create("spacy", "cli", entry_points=True)
|
cli = catalogue.create("spacy", "cli", entry_points=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ensure_populated(cls) -> None:
|
||||||
|
"""Ensure the registry is populated with all necessary components."""
|
||||||
|
from .registrations import populate_registry, REGISTRY_POPULATED
|
||||||
|
if not REGISTRY_POPULATED:
|
||||||
|
populate_registry()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_registry_names(cls) -> List[str]:
|
def get_registry_names(cls) -> List[str]:
|
||||||
"""List all available registries."""
|
"""List all available registries."""
|
||||||
|
cls.ensure_populated()
|
||||||
names = []
|
names = []
|
||||||
for name, value in inspect.getmembers(cls):
|
for name, value in inspect.getmembers(cls):
|
||||||
if not name.startswith("_") and isinstance(value, Registry):
|
if not name.startswith("_") and isinstance(value, Registry):
|
||||||
|
@ -144,6 +152,7 @@ class registry(thinc.registry):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, registry_name: str, func_name: str) -> Callable:
|
def get(cls, registry_name: str, func_name: str) -> Callable:
|
||||||
"""Get a registered function from the registry."""
|
"""Get a registered function from the registry."""
|
||||||
|
cls.ensure_populated()
|
||||||
# We're overwriting this classmethod so we're able to provide more
|
# We're overwriting this classmethod so we're able to provide more
|
||||||
# specific error messages and implement a fallback to spacy-legacy.
|
# specific error messages and implement a fallback to spacy-legacy.
|
||||||
if not hasattr(cls, registry_name):
|
if not hasattr(cls, registry_name):
|
||||||
|
@ -179,6 +188,7 @@ class registry(thinc.registry):
|
||||||
func_name (str): Name of the registered function.
|
func_name (str): Name of the registered function.
|
||||||
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
|
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
|
||||||
"""
|
"""
|
||||||
|
cls.ensure_populated()
|
||||||
# We're overwriting this classmethod so we're able to provide more
|
# We're overwriting this classmethod so we're able to provide more
|
||||||
# specific error messages and implement a fallback to spacy-legacy.
|
# specific error messages and implement a fallback to spacy-legacy.
|
||||||
if not hasattr(cls, registry_name):
|
if not hasattr(cls, registry_name):
|
||||||
|
@ -205,6 +215,7 @@ class registry(thinc.registry):
|
||||||
@classmethod
|
@classmethod
|
||||||
def has(cls, registry_name: str, func_name: str) -> bool:
|
def has(cls, registry_name: str, func_name: str) -> bool:
|
||||||
"""Check whether a function is available in a registry."""
|
"""Check whether a function is available in a registry."""
|
||||||
|
cls.ensure_populated()
|
||||||
if not hasattr(cls, registry_name):
|
if not hasattr(cls, registry_name):
|
||||||
return False
|
return False
|
||||||
reg = getattr(cls, registry_name)
|
reg = getattr(cls, registry_name)
|
||||||
|
@ -1323,7 +1334,6 @@ def filter_chain_spans(*spans: Iterable["Span"]) -> List["Span"]:
|
||||||
return filter_spans(itertools.chain(*spans))
|
return filter_spans(itertools.chain(*spans))
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.first_longest_spans_filter.v1")
|
|
||||||
def make_first_longest_spans_filter():
|
def make_first_longest_spans_filter():
|
||||||
return filter_chain_spans
|
return filter_chain_spans
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user