Fix registry

This commit is contained in:
Matthew Honnibal 2025-05-22 00:18:20 +02:00
parent c8d7dd968a
commit 752c2dd403
2 changed files with 47 additions and 32 deletions

View File

@ -183,6 +183,8 @@ class Language:
DOCS: https://spacy.io/api/language#init DOCS: https://spacy.io/api/language#init
""" """
from .pipeline.factories import register_factories
register_factories()
# We're only calling this to import all factories provided via entry # We're only calling this to import all factories provided via entry
# points. The factory decorator applied to these functions takes care # points. The factory decorator applied to these functions takes care
# of the rest. # of the rest.

View File

@ -26,7 +26,6 @@ def populate_registry() -> None:
) )
from .pipeline.lemmatizer import make_lemmatizer_scorer from .pipeline.lemmatizer import make_lemmatizer_scorer
from .pipeline.ner import make_ner_scorer from .pipeline.ner import make_ner_scorer
from .pipeline.sentencizer import senter_score as make_sentencizer_scorer
from .pipeline.senter import make_senter_scorer from .pipeline.senter import make_senter_scorer
from .pipeline.span_finder import make_span_finder_scorer from .pipeline.span_finder import make_span_finder_scorer
from .pipeline.spancat import ( from .pipeline.spancat import (
@ -35,6 +34,7 @@ def populate_registry() -> None:
build_preset_spans_suggester, build_preset_spans_suggester,
make_spancat_scorer, make_spancat_scorer,
) )
# Import the functions we refactored by removing direct registry decorators # Import the functions we refactored by removing direct registry decorators
from .pipeline.entity_linker import make_entity_linker_scorer from .pipeline.entity_linker import make_entity_linker_scorer
from .pipeline.span_ruler import ( from .pipeline.span_ruler import (
@ -45,6 +45,10 @@ def populate_registry() -> None:
from .pipeline.attributeruler import make_attribute_ruler_scorer from .pipeline.attributeruler import make_attribute_ruler_scorer
from .pipeline.dep_parser import make_parser_scorer from .pipeline.dep_parser import make_parser_scorer
from .pipeline.morphologizer import make_morphologizer_scorer from .pipeline.morphologizer import make_morphologizer_scorer
from .ml.models.entity_linker import load_kb, empty_kb_for_config, empty_kb
from .ml.models.entity_linker import create_candidates
from .ml.models.entity_linker import create_candidates_batch
from .language import load_lookups_data
from .lang.ja import create_tokenizer as create_japanese_tokenizer from .lang.ja import create_tokenizer as create_japanese_tokenizer
from .lang.zh import create_chinese_tokenizer from .lang.zh import create_chinese_tokenizer
from .lang.ko import create_tokenizer as create_korean_tokenizer from .lang.ko import create_tokenizer as create_korean_tokenizer
@ -56,6 +60,7 @@ def populate_registry() -> None:
from .pipeline.textcat import make_textcat_scorer from .pipeline.textcat import make_textcat_scorer
from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer
from .util import make_first_longest_spans_filter, registry from .util import make_first_longest_spans_filter, registry
from .matcher.levenshtein import make_levenshtein_compare
# Register miscellaneous components # Register miscellaneous components
registry.misc("spacy.first_longest_spans_filter.v1")( registry.misc("spacy.first_longest_spans_filter.v1")(
@ -64,9 +69,14 @@ def populate_registry() -> None:
registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester) registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester)
registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester) registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester)
registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester) registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester)
registry.misc("spacy.prioritize_new_ents_filter.v1")(make_prioritize_new_ents_filter) registry.misc("spacy.prioritize_new_ents_filter.v1")(
registry.misc("spacy.prioritize_existing_ents_filter.v1")(make_preserve_existing_ents_filter) make_prioritize_new_ents_filter
)
registry.misc("spacy.prioritize_existing_ents_filter.v1")(
make_preserve_existing_ents_filter
)
registry.misc("spacy.levenshtein_compare.v1")(make_levenshtein_compare) registry.misc("spacy.levenshtein_compare.v1")(make_levenshtein_compare)
# KB-related registrations
registry.misc("spacy.KBFromFile.v1")(load_kb) registry.misc("spacy.KBFromFile.v1")(load_kb)
registry.misc("spacy.EmptyKB.v2")(empty_kb_for_config) registry.misc("spacy.EmptyKB.v2")(empty_kb_for_config)
registry.misc("spacy.EmptyKB.v1")(empty_kb) registry.misc("spacy.EmptyKB.v1")(empty_kb)
@ -93,14 +103,7 @@ def populate_registry() -> None:
from .ml.featureextractor import FeatureExtractor from .ml.featureextractor import FeatureExtractor
from .ml.extract_spans import extract_spans from .ml.extract_spans import extract_spans
from .ml.extract_ngrams import extract_ngrams from .ml.extract_ngrams import extract_ngrams
from .ml.models.entity_linker import ( from .ml.models.entity_linker import build_nel_encoder
build_nel_encoder,
load_kb,
empty_kb_for_config,
empty_kb,
create_candidates,
create_candidates_batch
)
from .ml.models.textcat import ( from .ml.models.textcat import (
build_simple_cnn_text_classifier, build_simple_cnn_text_classifier,
build_bow_text_classifier, build_bow_text_classifier,
@ -108,35 +111,39 @@ def populate_registry() -> None:
build_text_classifier_v2, build_text_classifier_v2,
build_text_classifier_lowdata, build_text_classifier_lowdata,
build_textcat_parametric_attention_v1, build_textcat_parametric_attention_v1,
build_reduce_text_classifier build_reduce_text_classifier,
) )
from .ml.models.spancat import ( from .ml.models.spancat import (
build_linear_logistic, build_linear_logistic,
build_mean_max_reducer, build_mean_max_reducer,
build_spancat_model build_spancat_model,
) )
from .ml.models.span_finder import build_finder_model from .ml.models.span_finder import build_finder_model
from .ml.models.parser import build_tb_parser_model from .ml.models.parser import build_tb_parser_model
from .ml.models.multi_task import create_pretrain_vectors, create_pretrain_characters from .ml.models.multi_task import (
create_pretrain_vectors,
create_pretrain_characters,
)
from .ml.models.tagger import build_tagger_model from .ml.models.tagger import build_tagger_model
from .ml.staticvectors import StaticVectors from .ml.staticvectors import StaticVectors
from .ml._precomputable_affine import PrecomputableAffine from .ml._precomputable_affine import PrecomputableAffine
from .ml._character_embed import CharacterEmbed
from .ml.tb_framework import TransitionModel from .ml.tb_framework import TransitionModel
from .language import create_tokenizer, load_lookups_data from .language import create_tokenizer
from .matcher.levenshtein import make_levenshtein_compare
from .training.callbacks import create_copy_from_base_model from .training.callbacks import create_copy_from_base_model
from .ml.callbacks import create_models_with_nvtx_range, create_models_and_pipes_with_nvtx_range from .ml.callbacks import (
create_models_with_nvtx_range,
create_models_and_pipes_with_nvtx_range,
)
from .training.loggers import console_logger, console_logger_v3 from .training.loggers import console_logger, console_logger_v3
from .training.batchers import ( from .training.batchers import (
configure_minibatch_by_padded_size, configure_minibatch_by_padded_size,
configure_minibatch_by_words, configure_minibatch_by_words,
configure_minibatch configure_minibatch,
) )
from .training.augment import ( from .training.augment import (
create_combined_augmenter, create_combined_augmenter,
create_lower_casing_augmenter, create_lower_casing_augmenter,
create_orth_variants_augmenter create_orth_variants_augmenter,
) )
# Register scorers # Register scorers
@ -144,7 +151,6 @@ def populate_registry() -> None:
registry.scorers("spacy.ner_scorer.v1")(make_ner_scorer) registry.scorers("spacy.ner_scorer.v1")(make_ner_scorer)
# span_ruler_scorer removed as it's not in span_ruler.py # span_ruler_scorer removed as it's not in span_ruler.py
registry.scorers("spacy.entity_ruler_scorer.v1")(make_entityruler_scorer) registry.scorers("spacy.entity_ruler_scorer.v1")(make_entityruler_scorer)
registry.scorers("spacy.sentencizer_scorer.v1")(make_sentencizer_scorer)
registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer) registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer)
registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer) registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer)
registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer) registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer)
@ -158,12 +164,15 @@ def populate_registry() -> None:
registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer) registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer)
registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer) registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer)
registry.scorers("spacy.entity_linker_scorer.v1")(make_entity_linker_scorer) registry.scorers("spacy.entity_linker_scorer.v1")(make_entity_linker_scorer)
registry.scorers("spacy.overlapping_labeled_spans_scorer.v1")(make_overlapping_labeled_spans_scorer) registry.scorers("spacy.overlapping_labeled_spans_scorer.v1")(
make_overlapping_labeled_spans_scorer
)
registry.scorers("spacy.attribute_ruler_scorer.v1")(make_attribute_ruler_scorer) registry.scorers("spacy.attribute_ruler_scorer.v1")(make_attribute_ruler_scorer)
registry.scorers("spacy.parser_scorer.v1")(make_parser_scorer) registry.scorers("spacy.parser_scorer.v1")(make_parser_scorer)
registry.scorers("spacy.morphologizer_scorer.v1")(make_morphologizer_scorer) registry.scorers("spacy.morphologizer_scorer.v1")(make_morphologizer_scorer)
# Register tokenizers # Register tokenizers
registry.tokenizers("spacy.Tokenizer.v1")(create_tokenizer)
registry.tokenizers("spacy.ja.JapaneseTokenizer")(create_japanese_tokenizer) registry.tokenizers("spacy.ja.JapaneseTokenizer")(create_japanese_tokenizer)
registry.tokenizers("spacy.zh.ChineseTokenizer")(create_chinese_tokenizer) registry.tokenizers("spacy.zh.ChineseTokenizer")(create_chinese_tokenizer)
registry.tokenizers("spacy.ko.KoreanTokenizer")(create_korean_tokenizer) registry.tokenizers("spacy.ko.KoreanTokenizer")(create_korean_tokenizer)
@ -185,7 +194,9 @@ def populate_registry() -> None:
registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3) registry.architectures("spacy.TextCatBOW.v3")(build_bow_text_classifier_v3)
registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2) registry.architectures("spacy.TextCatEnsemble.v2")(build_text_classifier_v2)
registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata) registry.architectures("spacy.TextCatLowData.v1")(build_text_classifier_lowdata)
registry.architectures("spacy.TextCatParametricAttention.v1")(build_textcat_parametric_attention_v1) registry.architectures("spacy.TextCatParametricAttention.v1")(
build_textcat_parametric_attention_v1
)
registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier) registry.architectures("spacy.TextCatReduce.v1")(build_reduce_text_classifier)
registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model) registry.architectures("spacy.SpanCategorizer.v1")(build_spancat_model)
registry.architectures("spacy.SpanFinder.v1")(build_finder_model) registry.architectures("spacy.SpanFinder.v1")(build_finder_model)
@ -208,7 +219,9 @@ def populate_registry() -> None:
# Register callbacks # Register callbacks
registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model) registry.callbacks("spacy.copy_from_base_model.v1")(create_copy_from_base_model)
registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range) registry.callbacks("spacy.models_with_nvtx_range.v1")(create_models_with_nvtx_range)
registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")(create_models_and_pipes_with_nvtx_range) registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")(
create_models_and_pipes_with_nvtx_range
)
# Register loggers # Register loggers
registry.loggers("spacy.ConsoleLogger.v2")(console_logger) registry.loggers("spacy.ConsoleLogger.v2")(console_logger)