mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 09:42:26 +03:00
Move more functions
This commit is contained in:
parent
24b0670be7
commit
e6c190bc86
|
@ -22,7 +22,6 @@ from ...util import registry
|
||||||
from ..extract_spans import extract_spans
|
from ..extract_spans import extract_spans
|
||||||
|
|
||||||
|
|
||||||
@registry.layers("spacy.LinearLogistic.v1")
|
|
||||||
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
||||||
"""An output layer for multi-label classification. It uses a linear layer
|
"""An output layer for multi-label classification. It uses a linear layer
|
||||||
followed by a logistic activation.
|
followed by a logistic activation.
|
||||||
|
@ -30,7 +29,6 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
||||||
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
|
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
|
||||||
|
|
||||||
|
|
||||||
@registry.layers("spacy.mean_max_reducer.v1")
|
|
||||||
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
||||||
"""Reduce sequences by concatenating their mean and max pooled vectors,
|
"""Reduce sequences by concatenating their mean and max pooled vectors,
|
||||||
and then combine the concatenated vectors with a hidden layer.
|
and then combine the concatenated vectors with a hidden layer.
|
||||||
|
@ -46,7 +44,6 @@ def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.SpanCategorizer.v1")
|
|
||||||
def build_spancat_model(
|
def build_spancat_model(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
reducer: Model[Ragged, Floats2d],
|
reducer: Model[Ragged, Floats2d],
|
||||||
|
|
|
@ -44,7 +44,6 @@ from .tok2vec import get_tok2vec_width
|
||||||
NEG_VALUE = -5000
|
NEG_VALUE = -5000
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatCNN.v2")
|
|
||||||
def build_simple_cnn_text_classifier(
|
def build_simple_cnn_text_classifier(
|
||||||
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
@ -72,7 +71,6 @@ def resize_and_set_ref(model, new_nO, resizable_layer):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatBOW.v2")
|
|
||||||
def build_bow_text_classifier(
|
def build_bow_text_classifier(
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
ngram_size: int,
|
ngram_size: int,
|
||||||
|
@ -88,7 +86,6 @@ def build_bow_text_classifier(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatBOW.v3")
|
|
||||||
def build_bow_text_classifier_v3(
|
def build_bow_text_classifier_v3(
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
ngram_size: int,
|
ngram_size: int,
|
||||||
|
@ -142,7 +139,6 @@ def _build_bow_text_classifier(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatEnsemble.v2")
|
|
||||||
def build_text_classifier_v2(
|
def build_text_classifier_v2(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
linear_model: Model[List[Doc], Floats2d],
|
linear_model: Model[List[Doc], Floats2d],
|
||||||
|
@ -200,7 +196,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatLowData.v1")
|
|
||||||
def build_text_classifier_lowdata(
|
def build_text_classifier_lowdata(
|
||||||
width: int, dropout: Optional[float], nO: Optional[int] = None
|
width: int, dropout: Optional[float], nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
@ -221,7 +216,6 @@ def build_text_classifier_lowdata(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatParametricAttention.v1")
|
|
||||||
def build_textcat_parametric_attention_v1(
|
def build_textcat_parametric_attention_v1(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
|
@ -294,7 +288,6 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatReduce.v1")
|
|
||||||
def build_reduce_text_classifier(
|
def build_reduce_text_classifier(
|
||||||
tok2vec: Model,
|
tok2vec: Model,
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
|
|
|
@ -93,7 +93,28 @@ def populate_registry() -> None:
|
||||||
create_candidates,
|
create_candidates,
|
||||||
create_candidates_batch
|
create_candidates_batch
|
||||||
)
|
)
|
||||||
|
from .ml.models.textcat import (
|
||||||
|
build_simple_cnn_text_classifier,
|
||||||
|
build_bow_text_classifier,
|
||||||
|
build_bow_text_classifier_v3,
|
||||||
|
build_text_classifier_v2,
|
||||||
|
build_text_classifier_lowdata,
|
||||||
|
build_textcat_parametric_attention_v1,
|
||||||
|
build_reduce_text_classifier
|
||||||
|
)
|
||||||
|
from .ml.models.spancat import (
|
||||||
|
build_linear_logistic,
|
||||||
|
build_mean_max_reducer,
|
||||||
|
build_spancat_model
|
||||||
|
)
|
||||||
from .matcher.levenshtein import make_levenshtein_compare
|
from .matcher.levenshtein import make_levenshtein_compare
|
||||||
|
from .training.callbacks import create_copy_from_base_model
|
||||||
|
from .training.loggers import console_logger, console_logger_v3
|
||||||
|
from .training.batchers import (
|
||||||
|
configure_minibatch_by_padded_size,
|
||||||
|
configure_minibatch_by_words,
|
||||||
|
configure_minibatch
|
||||||
|
)
|
||||||
|
|
||||||
# Register scorers
|
# Register scorers
|
||||||
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
|
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
|
||||||
|
@ -127,11 +148,33 @@ def populate_registry() -> None:
|
||||||
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
|
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
|
||||||
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
|
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
|
||||||
registry.architectures("spacy.EntityLinker.v2")(build_nel_encoder)
|
registry.architectures("spacy.EntityLinker.v2")(build_nel_encoder)
|
||||||
|
registry.architectures("spacy.TextCatCNN.v2")(build_simple_cnn_text_classifier)
|
||||||
|
registry.architectures("spacy.TextCatBOW.v2")(build_bow_text_classifier)
|
||||||
|
registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3)
|
||||||
|
registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2)
|
||||||
|
registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata)
|
||||||
|
registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1)
|
||||||
|
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
|
||||||
|
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
|
||||||
|
|
||||||
# Register layers
|
# Register layers
|
||||||
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
|
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
|
||||||
registry.layers("spacy.extract_spans.v1")(extract_spans)
|
registry.layers("spacy.extract_spans.v1")(extract_spans)
|
||||||
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
|
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
|
||||||
|
registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic)
|
||||||
|
registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer)
|
||||||
|
|
||||||
|
# Register callbacks
|
||||||
|
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
|
||||||
|
|
||||||
|
# Register loggers
|
||||||
|
registry.loggers("spacy.ConsoleLogger.v2")(console_logger)
|
||||||
|
registry.loggers("spacy.ConsoleLogger.v3")(console_logger_v3)
|
||||||
|
|
||||||
|
# Register batchers
|
||||||
|
registry.batchers("spacy.batch_by_padded.v1")(configure_minibatch_by_padded_size)
|
||||||
|
registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words)
|
||||||
|
registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch)
|
||||||
|
|
||||||
# Set the flag to indicate that the registry has been populated
|
# Set the flag to indicate that the registry has been populated
|
||||||
REGISTRY_POPULATED = True
|
REGISTRY_POPULATED = True
|
||||||
|
|
|
@ -19,7 +19,6 @@ ItemT = TypeVar("ItemT")
|
||||||
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
||||||
|
|
||||||
|
|
||||||
@registry.batchers("spacy.batch_by_padded.v1")
|
|
||||||
def configure_minibatch_by_padded_size(
|
def configure_minibatch_by_padded_size(
|
||||||
*,
|
*,
|
||||||
size: Sizing,
|
size: Sizing,
|
||||||
|
@ -54,7 +53,6 @@ def configure_minibatch_by_padded_size(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.batchers("spacy.batch_by_words.v1")
|
|
||||||
def configure_minibatch_by_words(
|
def configure_minibatch_by_words(
|
||||||
*,
|
*,
|
||||||
size: Sizing,
|
size: Sizing,
|
||||||
|
@ -82,7 +80,6 @@ def configure_minibatch_by_words(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.batchers("spacy.batch_by_sequence.v1")
|
|
||||||
def configure_minibatch(
|
def configure_minibatch(
|
||||||
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
||||||
) -> BatcherT:
|
) -> BatcherT:
|
||||||
|
|
|
@ -7,7 +7,6 @@ if TYPE_CHECKING:
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
|
||||||
|
|
||||||
@registry.callbacks("spacy.copy_from_base_model.v1")
|
|
||||||
def create_copy_from_base_model(
|
def create_copy_from_base_model(
|
||||||
tokenizer: Optional[str] = None,
|
tokenizer: Optional[str] = None,
|
||||||
vocab: Optional[str] = None,
|
vocab: Optional[str] = None,
|
||||||
|
|
|
@ -29,7 +29,6 @@ def setup_table(
|
||||||
|
|
||||||
# We cannot rename this method as it's directly imported
|
# We cannot rename this method as it's directly imported
|
||||||
# and used by external packages such as spacy-loggers.
|
# and used by external packages such as spacy-loggers.
|
||||||
@registry.loggers("spacy.ConsoleLogger.v2")
|
|
||||||
def console_logger(
|
def console_logger(
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
console_output: bool = True,
|
console_output: bool = True,
|
||||||
|
@ -47,7 +46,6 @@ def console_logger(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.loggers("spacy.ConsoleLogger.v3")
|
|
||||||
def console_logger_v3(
|
def console_logger_v3(
|
||||||
progress_bar: Optional[str] = None,
|
progress_bar: Optional[str] = None,
|
||||||
console_output: bool = True,
|
console_output: bool = True,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user