Move more function registrations from decorators into populate_registry()

This commit is contained in:
Matthew Honnibal 2025-05-21 23:12:47 +02:00
parent e6c190bc86
commit ab5f1c1013
7 changed files with 27 additions and 9 deletions

View File

@ -7,7 +7,6 @@ from ..tokens import Doc
from ..util import registry from ..util import registry
@registry.layers("spacy.CharEmbed.v1")
def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]: def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
# nM: Number of dimensions per character. nC: Number of characters. # nM: Number of dimensions per character. nC: Number of characters.
return Model( return Model(

View File

@ -3,7 +3,6 @@ from thinc.api import Model, normal_init
from ..util import registry from ..util import registry
@registry.layers("spacy.PrecomputableAffine.v1")
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1): def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
model = Model( model = Model(
"precomputable_affine", "precomputable_affine",

View File

@ -50,7 +50,6 @@ def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
return nlp return nlp
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range( def create_models_with_nvtx_range(
forward_color: int = -1, backprop_color: int = -1 forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]: ) -> Callable[["Language"], "Language"]:
@ -110,7 +109,6 @@ def pipes_with_nvtx_range(
return nlp return nlp
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
def create_models_and_pipes_with_nvtx_range( def create_models_and_pipes_with_nvtx_range(
forward_color: int = -1, forward_color: int = -1,
backprop_color: int = -1, backprop_color: int = -1,

View File

@ -10,7 +10,6 @@ InT = List[Doc]
OutT = Floats2d OutT = Floats2d
@registry.architectures("spacy.SpanFinder.v1")
def build_finder_model( def build_finder_model(
tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT] tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT]
) -> Model[InT, OutT]: ) -> Model[InT, OutT]:

View File

@ -13,7 +13,6 @@ from ..vectors import Mode, Vectors
from ..vocab import Vocab from ..vocab import Vocab
@registry.layers("spacy.StaticVectors.v2")
def StaticVectors( def StaticVectors(
nO: Optional[int] = None, nO: Optional[int] = None,
nM: Optional[int] = None, nM: Optional[int] = None,

View File

@ -107,14 +107,27 @@ def populate_registry() -> None:
build_mean_max_reducer, build_mean_max_reducer,
build_spancat_model build_spancat_model
) )
from .ml.models.span_finder import build_finder_model
from .ml.models.parser import build_tb_parser_model
from .ml.models.multi_task import create_pretrain_vectors
from .ml.models.tagger import build_tagger_model
from .ml.staticvectors import StaticVectors
from .ml._precomputable_affine import PrecomputableAffine
from .ml._character_embed import CharacterEmbed
from .matcher.levenshtein import make_levenshtein_compare from .matcher.levenshtein import make_levenshtein_compare
from .training.callbacks import create_copy_from_base_model from .training.callbacks import create_copy_from_base_model
from .ml.callbacks import create_models_with_nvtx_range, create_models_and_pipes_with_nvtx_range
from .training.loggers import console_logger, console_logger_v3 from .training.loggers import console_logger, console_logger_v3
from .training.batchers import ( from .training.batchers import (
configure_minibatch_by_padded_size, configure_minibatch_by_padded_size,
configure_minibatch_by_words, configure_minibatch_by_words,
configure_minibatch configure_minibatch
) )
from .training.augment import (
create_combined_augmenter,
create_lower_casing_augmenter,
create_orth_variants_augmenter
)
# Register scorers # Register scorers
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer) registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
@ -156,6 +169,10 @@ def populate_registry() -> None:
registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1) registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1)
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier) registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model) registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
registry.architectures("spacy.SpanFinder.v1")(build_finder_model)
registry.architectures("spacy.TransitionBasedParser.v2")(build_tb_parser_model)
registry.architectures("spacy.PretrainVectors.v1")(create_pretrain_vectors)
registry.architectures("spacy.Tagger.v2")(build_tagger_model)
# Register layers # Register layers
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor) registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
@ -163,9 +180,14 @@ def populate_registry() -> None:
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams) registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic) registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic)
registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer) registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer)
registry.layers("spacy.StaticVectors.v2")(StaticVectors)
registry.layers("spacy.PrecomputableAffine.v1")(PrecomputableAffine)
registry.layers("spacy.CharEmbed.v1")(CharacterEmbed)
# Register callbacks # Register callbacks
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model) registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range)
registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")(create_models_and_pipes_with_nvtx_range)
# Register loggers # Register loggers
registry.loggers("spacy.ConsoleLogger.v2")(console_logger) registry.loggers("spacy.ConsoleLogger.v2")(console_logger)
@ -176,5 +198,10 @@ def populate_registry() -> None:
registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words) registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words)
registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch) registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch)
# Register augmenters
registry.augmenters("spacy.combined_augmenter.v1")(create_combined_augmenter)
registry.augmenters("spacy.lower_case.v1")(create_lower_casing_augmenter)
registry.augmenters("spacy.orth_variants.v1")(create_orth_variants_augmenter)
# Set the flag to indicate that the registry has been populated # Set the flag to indicate that the registry has been populated
REGISTRY_POPULATED = True REGISTRY_POPULATED = True

View File

@ -11,7 +11,6 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
@registry.augmenters("spacy.combined_augmenter.v1")
def create_combined_augmenter( def create_combined_augmenter(
lower_level: float, lower_level: float,
orth_level: float, orth_level: float,
@ -84,7 +83,6 @@ def combined_augmenter(
yield example yield example
@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter( def create_orth_variants_augmenter(
level: float, lower: float, orth_variants: Dict[str, List[Dict]] level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]: ) -> Callable[["Language", Example], Iterator[Example]]:
@ -102,7 +100,6 @@ def create_orth_variants_augmenter(
) )
@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter( def create_lower_casing_augmenter(
level: float, level: float,
) -> Callable[["Language", Example], Iterator[Example]]: ) -> Callable[["Language", Example], Iterator[Example]]: