Centralise registrations

This commit is contained in:
Matthew Honnibal 2025-05-19 13:07:21 +02:00
parent 43f87b991b
commit 15bd029be5
9 changed files with 24 additions and 20 deletions

View File

@ -100,7 +100,6 @@ def build_hash_embed_cnn_tok2vec(
)
@registry.architectures("spacy.Tok2Vec.v2")
def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]],
encode: Model[List[Floats2d], List[Floats2d]],
@ -121,7 +120,6 @@ def build_Tok2Vec_model(
return tok2vec
@registry.architectures("spacy.MultiHashEmbed.v2")
def MultiHashEmbed(
width: int,
attrs: List[Union[str, int]],
@ -199,7 +197,6 @@ def MultiHashEmbed(
return model
@registry.architectures("spacy.CharacterEmbed.v2")
def CharacterEmbed(
width: int,
rows: int,
@ -276,7 +273,6 @@ def CharacterEmbed(
return model
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
def MaxoutWindowEncoder(
width: int, window_size: int, maxout_pieces: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]:
@ -308,7 +304,6 @@ def MaxoutWindowEncoder(
return with_array(model, pad=receptive_field)
@registry.architectures("spacy.MishWindowEncoder.v2")
def MishWindowEncoder(
width: int, window_size: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]:
@ -331,7 +326,6 @@ def MishWindowEncoder(
return with_array(model)
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
def BiLSTMEncoder(
width: int, depth: int, dropout: float
) -> Model[List[Floats2d], List[Floats2d]]:

View File

@ -63,7 +63,6 @@ def entity_ruler_score(examples, **kwargs):
return get_ner_prf(examples)
@registry.scorers("spacy.entity_ruler_scorer.v1")
def make_entity_ruler_scorer():
    """Scorer factory for ``spacy.entity_ruler_scorer.v1``.

    Takes no arguments and returns the ``entity_ruler_score`` callable
    itself (it is not called here).
    """
    return entity_ruler_score

View File

@ -44,7 +44,6 @@ def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_token_attr(examples, "lemma", **kwargs)
@registry.scorers("spacy.lemmatizer_scorer.v1")
def make_lemmatizer_scorer():
    """Scorer factory for ``spacy.lemmatizer_scorer.v1``.

    Takes no arguments and returns the ``lemmatizer_score`` callable
    itself (it is not called here).
    """
    return lemmatizer_score

View File

@ -53,7 +53,6 @@ def senter_score(examples, **kwargs):
return results
@registry.scorers("spacy.senter_scorer.v1")
def make_senter_scorer():
    """Scorer factory for ``spacy.senter_scorer.v1``.

    Takes no arguments and returns the ``senter_score`` callable itself
    (it is not called here).
    """
    return senter_score

View File

@ -97,7 +97,6 @@ def make_span_finder(
)
@registry.scorers("spacy.span_finder_scorer.v1")
def make_span_finder_scorer():
    """Scorer factory for ``spacy.span_finder_scorer.v1``.

    Takes no arguments and returns the ``span_finder_score`` callable
    itself (it is not called here).
    """
    return span_finder_score

View File

@ -134,7 +134,6 @@ def preset_spans_suggester(
return output
@registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end
@ -143,7 +142,6 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
return partial(ngram_suggester, sizes=sizes)
@registry.misc("spacy.ngram_range_suggester.v1")
def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
"""Suggest all spans of the given lengths between a given min and max value - both inclusive.
Spans are returned as a ragged array of integers. The array has two columns,
@ -152,7 +150,6 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
return build_ngram_suggester(sizes)
@registry.misc("spacy.preset_spans_suggester.v1")
def build_preset_spans_suggester(spans_key: str) -> Suggester:
"""Suggest all spans that are already stored in doc.spans[spans_key].
This is useful when an upstream component is used to set the spans
@ -303,7 +300,6 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_spans(examples, **kwargs)
@registry.scorers("spacy.spancat_scorer.v1")
def make_spancat_scorer():
    """Scorer factory for ``spacy.spancat_scorer.v1``.

    Takes no arguments and returns the ``spancat_score`` callable itself
    (it is not called here).
    """
    return spancat_score

View File

@ -123,7 +123,6 @@ def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
)
@registry.scorers("spacy.textcat_scorer.v2")
def make_textcat_scorer():
    """Scorer factory for ``spacy.textcat_scorer.v2``.

    Takes no arguments and returns the ``textcat_score`` callable itself
    (it is not called here).
    """
    return textcat_score

View File

@ -124,7 +124,6 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
)
@registry.scorers("spacy.textcat_multilabel_scorer.v2")
def make_textcat_multilabel_scorer():
    """Scorer factory for ``spacy.textcat_multilabel_scorer.v2``.

    Takes no arguments and returns the ``textcat_multilabel_score``
    callable itself (it is not called here).
    """
    return textcat_multilabel_score

View File

@ -22,12 +22,23 @@ def populate_registry() -> None:
# Import all necessary modules
from .util import registry, make_first_longest_spans_filter
# Register miscellaneous components
registry.misc("spacy.first_longest_spans_filter.v1")(make_first_longest_spans_filter)
# Import all pipeline components that were using registry decorators
from .pipeline.tagger import make_tagger_scorer
from .pipeline.ner import make_ner_scorer
from .pipeline.lemmatizer import make_lemmatizer_scorer
from .pipeline.span_finder import make_span_finder_scorer
from .pipeline.spancat import make_spancat_scorer, build_ngram_suggester, build_ngram_range_suggester, build_preset_spans_suggester
from .pipeline.entityruler import make_entity_ruler_scorer as make_entityruler_scorer
from .pipeline.sentencizer import senter_score as make_sentencizer_scorer
from .pipeline.senter import make_senter_scorer
from .pipeline.textcat import make_textcat_scorer
from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer
# Register miscellaneous components
registry.misc("spacy.first_longest_spans_filter.v1")(make_first_longest_spans_filter)
registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester)
registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester)
registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester)
# Get references to the existing functions by importing them here; this is
# needed for the registry entries that were previously created by decorators.
@ -36,7 +47,7 @@ def populate_registry() -> None:
from .scorer import get_ner_prf # Used for entity_ruler_scorer
# Import ML components that use registry
from .ml.models.tok2vec import tok2vec_listener_v1, build_hash_embed_cnn_tok2vec
from .ml.models.tok2vec import tok2vec_listener_v1, build_hash_embed_cnn_tok2vec, build_Tok2Vec_model, MultiHashEmbed, CharacterEmbed, MaxoutWindowEncoder, MishWindowEncoder, BiLSTMEncoder
# Register scorers
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
@ -46,13 +57,22 @@ def populate_registry() -> None:
registry.scorers("spacy.sentencizer_scorer.v1")(make_sentencizer_scorer)
registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer)
registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer)
registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer)
registry.scorers("spacy.textcat_multilabel_scorer.v1")(make_textcat_multilabel_scorer)
registry.scorers("spacy.textcat_multilabel_scorer.v2")(make_textcat_multilabel_scorer)
registry.scorers("spacy.lemmatizer_scorer.v1")(make_lemmatizer_scorer)
registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer)
registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer)
# Register tok2vec architectures we've modified
registry.architectures("spacy.Tok2VecListener.v1")(tok2vec_listener_v1)
registry.architectures("spacy.HashEmbedCNN.v2")(build_hash_embed_cnn_tok2vec)
registry.architectures("spacy.Tok2Vec.v2")(build_Tok2Vec_model)
registry.architectures("spacy.MultiHashEmbed.v2")(MultiHashEmbed)
registry.architectures("spacy.CharacterEmbed.v2")(CharacterEmbed)
registry.architectures("spacy.MaxoutWindowEncoder.v2")(MaxoutWindowEncoder)
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
# Set the flag to indicate that the registry has been populated
REGISTRY_POPULATED = True