mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 17:22:25 +03:00
Centralise registrations
This commit is contained in:
parent
43f87b991b
commit
15bd029be5
|
@ -100,7 +100,6 @@ def build_hash_embed_cnn_tok2vec(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2Vec.v2")
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
|
@ -121,7 +120,6 @@ def build_Tok2Vec_model(
|
|||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures("spacy.MultiHashEmbed.v2")
|
||||
def MultiHashEmbed(
|
||||
width: int,
|
||||
attrs: List[Union[str, int]],
|
||||
|
@ -199,7 +197,6 @@ def MultiHashEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.CharacterEmbed.v2")
|
||||
def CharacterEmbed(
|
||||
width: int,
|
||||
rows: int,
|
||||
|
@ -276,7 +273,6 @@ def CharacterEmbed(
|
|||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.MaxoutWindowEncoder.v2")
|
||||
def MaxoutWindowEncoder(
|
||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -308,7 +304,6 @@ def MaxoutWindowEncoder(
|
|||
return with_array(model, pad=receptive_field)
|
||||
|
||||
|
||||
@registry.architectures("spacy.MishWindowEncoder.v2")
|
||||
def MishWindowEncoder(
|
||||
width: int, window_size: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
@ -331,7 +326,6 @@ def MishWindowEncoder(
|
|||
return with_array(model)
|
||||
|
||||
|
||||
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
|
||||
def BiLSTMEncoder(
|
||||
width: int, depth: int, dropout: float
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
|
|
|
@ -63,7 +63,6 @@ def entity_ruler_score(examples, **kwargs):
|
|||
return get_ner_prf(examples)
|
||||
|
||||
|
||||
@registry.scorers("spacy.entity_ruler_scorer.v1")
|
||||
def make_entity_ruler_scorer():
|
||||
return entity_ruler_score
|
||||
|
||||
|
|
|
@ -44,7 +44,6 @@ def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
return Scorer.score_token_attr(examples, "lemma", **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.lemmatizer_scorer.v1")
|
||||
def make_lemmatizer_scorer():
|
||||
return lemmatizer_score
|
||||
|
||||
|
|
|
@ -53,7 +53,6 @@ def senter_score(examples, **kwargs):
|
|||
return results
|
||||
|
||||
|
||||
@registry.scorers("spacy.senter_scorer.v1")
|
||||
def make_senter_scorer():
|
||||
return senter_score
|
||||
|
||||
|
|
|
@ -97,7 +97,6 @@ def make_span_finder(
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.span_finder_scorer.v1")
|
||||
def make_span_finder_scorer():
|
||||
return span_finder_score
|
||||
|
||||
|
|
|
@ -134,7 +134,6 @@ def preset_spans_suggester(
|
|||
return output
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
|
@ -143,7 +142,6 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
|||
return partial(ngram_suggester, sizes=sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||
def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
|
||||
"""Suggest all spans of the given lengths between a given min and max value - both inclusive.
|
||||
Spans are returned as a ragged array of integers. The array has two columns,
|
||||
|
@ -152,7 +150,6 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
|
|||
return build_ngram_suggester(sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.preset_spans_suggester.v1")
|
||||
def build_preset_spans_suggester(spans_key: str) -> Suggester:
|
||||
"""Suggest all spans that are already stored in doc.spans[spans_key].
|
||||
This is useful when an upstream component is used to set the spans
|
||||
|
@ -303,7 +300,6 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
return Scorer.score_spans(examples, **kwargs)
|
||||
|
||||
|
||||
@registry.scorers("spacy.spancat_scorer.v1")
|
||||
def make_spancat_scorer():
|
||||
return spancat_score
|
||||
|
||||
|
|
|
@ -123,7 +123,6 @@ def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_scorer.v2")
|
||||
def make_textcat_scorer():
|
||||
return textcat_score
|
||||
|
||||
|
|
|
@ -124,7 +124,6 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_multilabel_scorer.v2")
|
||||
def make_textcat_multilabel_scorer():
|
||||
return textcat_multilabel_score
|
||||
|
||||
|
|
|
@ -22,12 +22,23 @@ def populate_registry() -> None:
|
|||
# Import all necessary modules
|
||||
from .util import registry, make_first_longest_spans_filter
|
||||
|
||||
# Register miscellaneous components
|
||||
registry.misc("spacy.first_longest_spans_filter.v1")(make_first_longest_spans_filter)
|
||||
|
||||
# Import all pipeline components that were using registry decorators
|
||||
from .pipeline.tagger import make_tagger_scorer
|
||||
from .pipeline.ner import make_ner_scorer
|
||||
from .pipeline.lemmatizer import make_lemmatizer_scorer
|
||||
from .pipeline.span_finder import make_span_finder_scorer
|
||||
from .pipeline.spancat import make_spancat_scorer, build_ngram_suggester, build_ngram_range_suggester, build_preset_spans_suggester
|
||||
from .pipeline.entityruler import make_entity_ruler_scorer as make_entityruler_scorer
|
||||
from .pipeline.sentencizer import senter_score as make_sentencizer_scorer
|
||||
from .pipeline.senter import make_senter_scorer
|
||||
from .pipeline.textcat import make_textcat_scorer
|
||||
from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer
|
||||
|
||||
# Register miscellaneous components
|
||||
registry.misc("spacy.first_longest_spans_filter.v1")(make_first_longest_spans_filter)
|
||||
registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester)
|
||||
registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester)
|
||||
registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester)
|
||||
|
||||
# Need to get references to the existing functions in registry by importing the function that is there
|
||||
# For the registry that was previously decorated
|
||||
|
@ -36,7 +47,7 @@ def populate_registry() -> None:
|
|||
from .scorer import get_ner_prf # Used for entity_ruler_scorer
|
||||
|
||||
# Import ML components that use registry
|
||||
from .ml.models.tok2vec import tok2vec_listener_v1, build_hash_embed_cnn_tok2vec
|
||||
from .ml.models.tok2vec import tok2vec_listener_v1, build_hash_embed_cnn_tok2vec, build_Tok2Vec_model, MultiHashEmbed, CharacterEmbed, MaxoutWindowEncoder, MishWindowEncoder, BiLSTMEncoder
|
||||
|
||||
# Register scorers
|
||||
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
|
||||
|
@ -46,13 +57,22 @@ def populate_registry() -> None:
|
|||
registry.scorers("spacy.sentencizer_scorer.v1")(make_sentencizer_scorer)
|
||||
registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer)
|
||||
registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer)
|
||||
registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer)
|
||||
registry.scorers("spacy.textcat_multilabel_scorer.v1")(make_textcat_multilabel_scorer)
|
||||
registry.scorers("spacy.textcat_multilabel_scorer.v2")(make_textcat_multilabel_scorer)
|
||||
registry.scorers("spacy.lemmatizer_scorer.v1")(make_lemmatizer_scorer)
|
||||
registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer)
|
||||
registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer)
|
||||
|
||||
# Register tok2vec architectures we've modified
|
||||
registry.architectures("spacy.Tok2VecListener.v1")(tok2vec_listener_v1)
|
||||
registry.architectures("spacy.HashEmbedCNN.v2")(build_hash_embed_cnn_tok2vec)
|
||||
registry.architectures("spacy.Tok2Vec.v2")(build_Tok2Vec_model)
|
||||
registry.architectures("spacy.MultiHashEmbed.v2")(MultiHashEmbed)
|
||||
registry.architectures("spacy.CharacterEmbed.v2")(CharacterEmbed)
|
||||
registry.architectures("spacy.MaxoutWindowEncoder.v2")(MaxoutWindowEncoder)
|
||||
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
|
||||
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
|
||||
|
||||
# Set the flag to indicate that the registry has been populated
|
||||
REGISTRY_POPULATED = True
|
Loading…
Reference in New Issue
Block a user