mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 09:42:26 +03:00
Move more functions
This commit is contained in:
parent
e6c190bc86
commit
ab5f1c1013
|
@ -7,7 +7,6 @@ from ..tokens import Doc
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.CharEmbed.v1")
|
||||
def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
|
||||
# nM: Number of dimensions per character. nC: Number of characters.
|
||||
return Model(
|
||||
|
|
|
@ -3,7 +3,6 @@ from thinc.api import Model, normal_init
|
|||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.PrecomputableAffine.v1")
|
||||
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
|
||||
model = Model(
|
||||
"precomputable_affine",
|
||||
|
|
|
@ -50,7 +50,6 @@ def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
|
|||
return nlp
|
||||
|
||||
|
||||
@registry.callbacks("spacy.models_with_nvtx_range.v1")
|
||||
def create_models_with_nvtx_range(
|
||||
forward_color: int = -1, backprop_color: int = -1
|
||||
) -> Callable[["Language"], "Language"]:
|
||||
|
@ -110,7 +109,6 @@ def pipes_with_nvtx_range(
|
|||
return nlp
|
||||
|
||||
|
||||
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
|
||||
def create_models_and_pipes_with_nvtx_range(
|
||||
forward_color: int = -1,
|
||||
backprop_color: int = -1,
|
||||
|
|
|
@ -10,7 +10,6 @@ InT = List[Doc]
|
|||
OutT = Floats2d
|
||||
|
||||
|
||||
@registry.architectures("spacy.SpanFinder.v1")
|
||||
def build_finder_model(
|
||||
tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT]
|
||||
) -> Model[InT, OutT]:
|
||||
|
|
|
@ -13,7 +13,6 @@ from ..vectors import Mode, Vectors
|
|||
from ..vocab import Vocab
|
||||
|
||||
|
||||
@registry.layers("spacy.StaticVectors.v2")
|
||||
def StaticVectors(
|
||||
nO: Optional[int] = None,
|
||||
nM: Optional[int] = None,
|
||||
|
|
|
@ -107,14 +107,27 @@ def populate_registry() -> None:
|
|||
build_mean_max_reducer,
|
||||
build_spancat_model
|
||||
)
|
||||
from .ml.models.span_finder import build_finder_model
|
||||
from .ml.models.parser import build_tb_parser_model
|
||||
from .ml.models.multi_task import create_pretrain_vectors
|
||||
from .ml.models.tagger import build_tagger_model
|
||||
from .ml.staticvectors import StaticVectors
|
||||
from .ml._precomputable_affine import PrecomputableAffine
|
||||
from .ml._character_embed import CharacterEmbed
|
||||
from .matcher.levenshtein import make_levenshtein_compare
|
||||
from .training.callbacks import create_copy_from_base_model
|
||||
from .ml.callbacks import create_models_with_nvtx_range, create_models_and_pipes_with_nvtx_range
|
||||
from .training.loggers import console_logger, console_logger_v3
|
||||
from .training.batchers import (
|
||||
configure_minibatch_by_padded_size,
|
||||
configure_minibatch_by_words,
|
||||
configure_minibatch
|
||||
)
|
||||
from .training.augment import (
|
||||
create_combined_augmenter,
|
||||
create_lower_casing_augmenter,
|
||||
create_orth_variants_augmenter
|
||||
)
|
||||
|
||||
# Register scorers
|
||||
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
|
||||
|
@ -156,6 +169,10 @@ def populate_registry() -> None:
|
|||
registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1)
|
||||
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
|
||||
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
|
||||
registry.architectures("spacy.SpanFinder.v1")(build_finder_model)
|
||||
registry.architectures("spacy.TransitionBasedParser.v2")(build_tb_parser_model)
|
||||
registry.architectures("spacy.PretrainVectors.v1")(create_pretrain_vectors)
|
||||
registry.architectures("spacy.Tagger.v2")(build_tagger_model)
|
||||
|
||||
# Register layers
|
||||
registry.layers("spacy.FeatureExtractor.v1")(FeatureExtractor)
|
||||
|
@ -163,9 +180,14 @@ def populate_registry() -> None:
|
|||
registry.layers("spacy.extract_ngrams.v1")(extract_ngrams)
|
||||
registry.layers("spacy.LinearLogistic.v1")(build_linear_logistic)
|
||||
registry.layers("spacy.mean_max_reducer.v1")(build_mean_max_reducer)
|
||||
registry.layers("spacy.StaticVectors.v2")(StaticVectors)
|
||||
registry.layers("spacy.PrecomputableAffine.v1")(PrecomputableAffine)
|
||||
registry.layers("spacy.CharEmbed.v1")(CharacterEmbed)
|
||||
|
||||
# Register callbacks
|
||||
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
|
||||
registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range)
|
||||
registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")(create_models_and_pipes_with_nvtx_range)
|
||||
|
||||
# Register loggers
|
||||
registry.loggers("spacy.ConsoleLogger.v2")(console_logger)
|
||||
|
@ -176,5 +198,10 @@ def populate_registry() -> None:
|
|||
registry.batchers("spacy.batch_by_words.v1")(configure_minibatch_by_words)
|
||||
registry.batchers("spacy.batch_by_sequence.v1")(configure_minibatch)
|
||||
|
||||
# Register augmenters
|
||||
registry.augmenters("spacy.combined_augmenter.v1")(create_combined_augmenter)
|
||||
registry.augmenters("spacy.lower_case.v1")(create_lower_casing_augmenter)
|
||||
registry.augmenters("spacy.orth_variants.v1")(create_orth_variants_augmenter)
|
||||
|
||||
# Set the flag to indicate that the registry has been populated
|
||||
REGISTRY_POPULATED = True
|
||||
|
|
|
@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
|||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
@registry.augmenters("spacy.combined_augmenter.v1")
|
||||
def create_combined_augmenter(
|
||||
lower_level: float,
|
||||
orth_level: float,
|
||||
|
@ -84,7 +83,6 @@ def combined_augmenter(
|
|||
yield example
|
||||
|
||||
|
||||
@registry.augmenters("spacy.orth_variants.v1")
|
||||
def create_orth_variants_augmenter(
|
||||
level: float, lower: float, orth_variants: Dict[str, List[Dict]]
|
||||
) -> Callable[["Language", Example], Iterator[Example]]:
|
||||
|
@ -102,7 +100,6 @@ def create_orth_variants_augmenter(
|
|||
)
|
||||
|
||||
|
||||
@registry.augmenters("spacy.lower_case.v1")
|
||||
def create_lower_casing_augmenter(
|
||||
level: float,
|
||||
) -> Callable[["Language", Example], Iterator[Example]]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user