Move more functions

This commit is contained in:
Matthew Honnibal 2025-05-21 22:59:54 +02:00
parent 24b0670be7
commit e6c190bc86
6 changed files with 43 additions and 16 deletions

View File

@@ -22,7 +22,6 @@ from ...util import registry
from ..extract_spans import extract_spans
@registry.layers("spacy.LinearLogistic.v1")
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
"""An output layer for multi-label classification. It uses a linear layer
followed by a logistic activation.
@@ -30,7 +29,6 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
@registry.layers("spacy.mean_max_reducer.v1")
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
"""Reduce sequences by concatenating their mean and max pooled vectors,
and then combine the concatenated vectors with a hidden layer.
@@ -46,7 +44,6 @@ def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
)
@registry.architectures("spacy.SpanCategorizer.v1")
def build_spancat_model(
tok2vec: Model[List[Doc], List[Floats2d]],
reducer: Model[Ragged, Floats2d],

View File

@@ -44,7 +44,6 @@ from .tok2vec import get_tok2vec_width
NEG_VALUE = -5000
@registry.architectures("spacy.TextCatCNN.v2")
def build_simple_cnn_text_classifier(
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
@@ -72,7 +71,6 @@ def resize_and_set_ref(model, new_nO, resizable_layer):
return model
@registry.architectures("spacy.TextCatBOW.v2")
def build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
@@ -88,7 +86,6 @@ def build_bow_text_classifier(
)
@registry.architectures("spacy.TextCatBOW.v3")
def build_bow_text_classifier_v3(
exclusive_classes: bool,
ngram_size: int,
@@ -142,7 +139,6 @@ def _build_bow_text_classifier(
return model
@registry.architectures("spacy.TextCatEnsemble.v2")
def build_text_classifier_v2(
tok2vec: Model[List[Doc], List[Floats2d]],
linear_model: Model[List[Doc], Floats2d],
@@ -200,7 +196,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
return model
@registry.architectures("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(
width: int, dropout: Optional[float], nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
@@ -221,7 +216,6 @@ def build_text_classifier_lowdata(
return model
@registry.architectures("spacy.TextCatParametricAttention.v1")
def build_textcat_parametric_attention_v1(
tok2vec: Model[List[Doc], List[Floats2d]],
exclusive_classes: bool,
@@ -294,7 +288,6 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
return model
@registry.architectures("spacy.TextCatReduce.v1")
def build_reduce_text_classifier(
tok2vec: Model,
exclusive_classes: bool,

View File

@@ -93,7 +93,28 @@ def populate_registry() -> None:
create_candidates,
create_candidates_batch
)
from .ml.models.textcat import (
build_simple_cnn_text_classifier,
build_bow_text_classifier,
build_bow_text_classifier_v3,
build_text_classifier_v2,
build_text_classifier_lowdata,
build_textcat_parametric_attention_v1,
build_reduce_text_classifier
)
from .ml.models.spancat import (
build_linear_logistic,
build_mean_max_reducer,
build_spancat_model
)
from .matcher.levenshtein import make_levenshtein_compare
from .training.callbacks import create_copy_from_base_model
from .training.loggers import console_logger, console_logger_v3
from .training.batchers import (
configure_minibatch_by_padded_size,
configure_minibatch_by_words,
configure_minibatch
)
# Register scorers
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
@@ -127,11 +148,33 @@ def populate_registry() -> None:
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
registry.architectures("spacy.EntityLinker.v2")(build_nel_encoder)
registry.architectures("spacy.TextCatCNN.v2")(build_simple_cnn_text_classifier)
registry.architectures("spacy.TextCatBOW.v2")(build_bow_text_classifier)
registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3)
registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2)
registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata)
registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1)
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
# Register layers
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
registry.layers("spacy.extract_spans.v1")(extract_spans)
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic)
registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer)
# Register callbacks
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
# Register loggers
registry.loggers("spacy.ConsoleLogger.v2")(console_logger)
registry.loggers("spacy.ConsoleLogger.v3")(console_logger_v3)
# Register batchers
registry.batchers("spacy.batch_by_padded.v1")(configure_minibatch_by_padded_size)
registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words)
registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch)
# Set the flag to indicate that the registry has been populated
REGISTRY_POPULATED = True

View File

@@ -19,7 +19,6 @@ ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
*,
size: Sizing,
@@ -54,7 +53,6 @@ def configure_minibatch_by_padded_size(
)
@registry.batchers("spacy.batch_by_words.v1")
def configure_minibatch_by_words(
*,
size: Sizing,
@@ -82,7 +80,6 @@ def configure_minibatch_by_words(
)
@registry.batchers("spacy.batch_by_sequence.v1")
def configure_minibatch(
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:

View File

@@ -7,7 +7,6 @@ if TYPE_CHECKING:
from ..language import Language
@registry.callbacks("spacy.copy_from_base_model.v1")
def create_copy_from_base_model(
tokenizer: Optional[str] = None,
vocab: Optional[str] = None,

View File

@@ -29,7 +29,6 @@ def setup_table(
# We cannot rename this method as it's directly imported
# and used by external packages such as spacy-loggers.
@registry.loggers("spacy.ConsoleLogger.v2")
def console_logger(
progress_bar: bool = False,
console_output: bool = True,
@@ -47,7 +46,6 @@ def console_logger(
)
@registry.loggers("spacy.ConsoleLogger.v3")
def console_logger_v3(
progress_bar: Optional[str] = None,
console_output: bool = True,