This commit is contained in:
Matthew Honnibal 2025-05-19 17:41:34 +02:00
parent cda2bd01d4
commit c3f9fab5e8
3 changed files with 157 additions and 71 deletions

View File

@ -12,43 +12,62 @@ REGISTRY_POPULATED = False
# Global flag to track if factories have been registered
FACTORIES_REGISTERED = False
def populate_registry() -> None:
"""Populate the registry with all necessary components.
This function should be called before accessing the registry, to ensure
This function should be called before accessing the registry, to ensure
it's populated. The function uses a global flag to prevent repopulation.
"""
global REGISTRY_POPULATED
if REGISTRY_POPULATED:
return
# Import all necessary modules
from .util import registry, make_first_longest_spans_filter
# Import all pipeline components that were using registry decorators
from .pipeline.tagger import make_tagger_scorer
from .pipeline.ner import make_ner_scorer
from .pipeline.lemmatizer import make_lemmatizer_scorer
from .pipeline.span_finder import make_span_finder_scorer
from .pipeline.spancat import make_spancat_scorer, build_ngram_suggester, build_ngram_range_suggester, build_preset_spans_suggester
from .pipeline.entityruler import make_entity_ruler_scorer as make_entityruler_scorer
from .pipeline.spancat import (
make_spancat_scorer,
build_ngram_suggester,
build_ngram_range_suggester,
build_preset_spans_suggester,
)
from .pipeline.entityruler import (
make_entity_ruler_scorer as make_entityruler_scorer,
)
from .pipeline.sentencizer import senter_score as make_sentencizer_scorer
from .pipeline.senter import make_senter_scorer
from .pipeline.textcat import make_textcat_scorer
from .pipeline.textcat_multilabel import make_textcat_multilabel_scorer
# Register miscellaneous components
registry.misc("spacy.first_longest_spans_filter.v1")(make_first_longest_spans_filter)
# Register miscellaneous components
registry.misc("spacy.first_longest_spans_filter.v1")(
make_first_longest_spans_filter
)
registry.misc("spacy.ngram_suggester.v1")(build_ngram_suggester)
registry.misc("spacy.ngram_range_suggester.v1")(build_ngram_range_suggester)
registry.misc("spacy.preset_spans_suggester.v1")(build_preset_spans_suggester)
# Need to get references to the existing functions in registry by importing the function that is there
# For the registry that was previously decorated
# Import ML components that use registry
from .ml.models.tok2vec import tok2vec_listener_v1, build_hash_embed_cnn_tok2vec, build_Tok2Vec_model, MultiHashEmbed, CharacterEmbed, MaxoutWindowEncoder, MishWindowEncoder, BiLSTMEncoder
from .ml.models.tok2vec import (
tok2vec_listener_v1,
build_hash_embed_cnn_tok2vec,
build_Tok2Vec_model,
MultiHashEmbed,
CharacterEmbed,
MaxoutWindowEncoder,
MishWindowEncoder,
BiLSTMEncoder,
)
# Register scorers
registry.scorers("spacy.tagger_scorer.v1")(make_tagger_scorer)
registry.scorers("spacy.ner_scorer.v1")(make_ner_scorer)
@ -58,12 +77,16 @@ def populate_registry() -> None:
registry.scorers("spacy.senter_scorer.v1")(make_senter_scorer)
registry.scorers("spacy.textcat_scorer.v1")(make_textcat_scorer)
registry.scorers("spacy.textcat_scorer.v2")(make_textcat_scorer)
registry.scorers("spacy.textcat_multilabel_scorer.v1")(make_textcat_multilabel_scorer)
registry.scorers("spacy.textcat_multilabel_scorer.v2")(make_textcat_multilabel_scorer)
registry.scorers("spacy.textcat_multilabel_scorer.v1")(
make_textcat_multilabel_scorer
)
registry.scorers("spacy.textcat_multilabel_scorer.v2")(
make_textcat_multilabel_scorer
)
registry.scorers("spacy.lemmatizer_scorer.v1")(make_lemmatizer_scorer)
registry.scorers("spacy.span_finder_scorer.v1")(make_span_finder_scorer)
registry.scorers("spacy.spancat_scorer.v1")(make_spancat_scorer)
# Register tok2vec architectures we've modified
registry.architectures("spacy.Tok2VecListener.v1")(tok2vec_listener_v1)
registry.architectures("spacy.HashEmbedCNN.v2")(build_hash_embed_cnn_tok2vec)
@ -73,33 +96,52 @@ def populate_registry() -> None:
registry.architectures("spacy.MaxoutWindowEncoder.v2")(MaxoutWindowEncoder)
registry.architectures("spacy.MishWindowEncoder.v2")(MishWindowEncoder)
registry.architectures("spacy.TorchBiLSTMEncoder.v1")(BiLSTMEncoder)
# Register factory components
register_factories()
# Set the flag to indicate that the registry has been populated
REGISTRY_POPULATED = True
def register_factories() -> None:
"""Register all factories with the registry.
This function registers all pipeline component factories, centralizing
the registrations that were previously done with @Language.factory decorators.
"""
global FACTORIES_REGISTERED
from .language import Language
from .pipeline.sentencizer import Sentencizer
if FACTORIES_REGISTERED:
return
from .language import Language
# TODO: We seem to still get cycle problems with these functions defined in Cython. We need
# a Python _factories module maybe?
def make_sentencizer(
nlp: Language,
name: str,
punct_chars: Optional[List[str]],
overwrite: bool,
scorer: Optional[Callable],
):
return Sentencizer(
name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
)
# Import factory default configurations
from .pipeline.entity_linker import DEFAULT_NEL_MODEL
from .pipeline.entityruler import DEFAULT_ENT_ID_SEP
from .pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from .pipeline.senter import DEFAULT_SENTER_MODEL
from .pipeline.morphologizer import DEFAULT_MORPH_MODEL
from .pipeline.spancat import DEFAULT_SPANCAT_MODEL, DEFAULT_SPANCAT_SINGLELABEL_MODEL, DEFAULT_SPANS_KEY
from .pipeline.spancat import (
DEFAULT_SPANCAT_MODEL,
DEFAULT_SPANCAT_SINGLELABEL_MODEL,
DEFAULT_SPANS_KEY,
)
from .pipeline.span_ruler import DEFAULT_SPANS_KEY as SPAN_RULER_DEFAULT_SPANS_KEY
from .pipeline.edit_tree_lemmatizer import DEFAULT_EDIT_TREE_LEMMATIZER_MODEL
from .pipeline.textcat_multilabel import DEFAULT_MULTI_TEXTCAT_MODEL
@ -108,7 +150,7 @@ def register_factories() -> None:
from .pipeline.dep_parser import DEFAULT_PARSER_MODEL
from .pipeline.tagger import DEFAULT_TAGGER_MODEL
from .pipeline.multitask import DEFAULT_MT_MODEL
# Import all factory functions
from .pipeline.attributeruler import make_attribute_ruler
from .pipeline.entity_linker import make_entity_linker
@ -120,7 +162,10 @@ def register_factories() -> None:
from .pipeline.senter import make_senter
from .pipeline.morphologizer import make_morphologizer
from .pipeline.spancat import make_spancat, make_spancat_singlelabel
from .pipeline.span_ruler import make_entity_ruler as make_span_entity_ruler, make_span_ruler
from .pipeline.span_ruler import (
make_entity_ruler as make_span_entity_ruler,
make_span_ruler,
)
from .pipeline.edit_tree_lemmatizer import make_edit_tree_lemmatizer
from .pipeline.textcat_multilabel import make_multilabel_textcat
from .pipeline.span_finder import make_span_finder
@ -128,11 +173,12 @@ def register_factories() -> None:
from .pipeline.dep_parser import make_parser, make_beam_parser
from .pipeline.tagger import make_tagger
from .pipeline.multitask import make_nn_labeller
from .pipeline.sentencizer import make_sentencizer
# from .pipeline.sentencizer import make_sentencizer
# Register factories using the same pattern as Language.factory decorator
# We use Language.factory()() pattern which exactly mimics the decorator
# attributeruler
Language.factory(
"attribute_ruler",
@ -141,7 +187,7 @@ def register_factories() -> None:
"scorer": {"@scorers": "spacy.attribute_ruler_scorer.v1"},
},
)(make_attribute_ruler)
# entity_linker
Language.factory(
"entity_linker",
@ -169,7 +215,7 @@ def register_factories() -> None:
"nel_micro_p": None,
},
)(make_entity_linker)
# entity_ruler
Language.factory(
"entity_ruler",
@ -189,7 +235,7 @@ def register_factories() -> None:
"ents_per_type": None,
},
)(make_entity_ruler)
# lemmatizer
Language.factory(
"lemmatizer",
@ -202,7 +248,7 @@ def register_factories() -> None:
},
default_score_weights={"lemma_acc": 1.0},
)(make_lemmatizer)
# textcat
Language.factory(
"textcat",
@ -225,49 +271,57 @@ def register_factories() -> None:
"cats_f_per_type": None,
},
)(make_textcat)
# token_splitter
Language.factory(
"token_splitter",
default_config={"min_length": 25, "split_length": 10},
retokenizes=True,
)(make_token_splitter)
# doc_cleaner
Language.factory(
"doc_cleaner",
default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
)(make_doc_cleaner)
# tok2vec
Language.factory(
"tok2vec",
assigns=["doc.tensor"],
default_config={"model": DEFAULT_TOK2VEC_MODEL}
"tok2vec",
assigns=["doc.tensor"],
default_config={"model": DEFAULT_TOK2VEC_MODEL},
)(make_tok2vec)
# senter
Language.factory(
"senter",
assigns=["token.is_sent_start"],
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
default_config={
"model": DEFAULT_SENTER_MODEL,
"overwrite": False,
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
},
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)(make_senter)
# morphologizer
Language.factory(
"morphologizer",
assigns=["token.morph", "token.pos"],
default_config={
"model": DEFAULT_MORPH_MODEL,
"overwrite": True,
"model": DEFAULT_MORPH_MODEL,
"overwrite": True,
"extend": False,
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"},
"label_smoothing": 0.0
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"},
"label_smoothing": 0.0,
},
default_score_weights={
"pos_acc": 0.5,
"morph_acc": 0.5,
"morph_per_feat": None,
},
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
)(make_morphologizer)
# spancat
Language.factory(
"spancat",
@ -282,7 +336,7 @@ def register_factories() -> None:
},
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
)(make_spancat)
# spancat_singlelabel
Language.factory(
"spancat_singlelabel",
@ -297,7 +351,7 @@ def register_factories() -> None:
},
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
)(make_spancat_singlelabel)
# future_entity_ruler
Language.factory(
"future_entity_ruler",
@ -317,7 +371,7 @@ def register_factories() -> None:
"ents_per_type": None,
},
)(make_span_entity_ruler)
# span_ruler
Language.factory(
"span_ruler",
@ -343,7 +397,7 @@ def register_factories() -> None:
f"spans_{SPAN_RULER_DEFAULT_SPANS_KEY}_per_type": None,
},
)(make_span_ruler)
# trainable_lemmatizer
Language.factory(
"trainable_lemmatizer",
@ -359,7 +413,7 @@ def register_factories() -> None:
},
default_score_weights={"lemma_acc": 1.0},
)(make_edit_tree_lemmatizer)
# textcat_multilabel
Language.factory(
"textcat_multilabel",
@ -382,7 +436,7 @@ def register_factories() -> None:
"cats_f_per_type": None,
},
)(make_multilabel_textcat)
# span_finder
Language.factory(
"span_finder",
@ -401,7 +455,7 @@ def register_factories() -> None:
f"spans_{DEFAULT_SPANS_KEY}_r": 0.0,
},
)(make_span_finder)
# ner
Language.factory(
"ner",
@ -413,9 +467,14 @@ def register_factories() -> None:
"incorrect_spans_key": None,
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
default_score_weights={
"ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
)(make_ner)
# beam_ner
Language.factory(
"beam_ner",
@ -430,9 +489,14 @@ def register_factories() -> None:
"incorrect_spans_key": None,
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
default_score_weights={
"ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
)(make_beam_ner)
# parser
Language.factory(
"parser",
@ -454,7 +518,7 @@ def register_factories() -> None:
"sents_f": 0.0,
},
)(make_parser)
# beam_parser
Language.factory(
"beam_parser",
@ -479,28 +543,48 @@ def register_factories() -> None:
"sents_f": 0.0,
},
)(make_beam_parser)
# tagger
Language.factory(
"tagger",
assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!", "label_smoothing": 0.0},
default_score_weights={"tag_acc": 1.0, "pos_acc": 0.0, "tag_micro_p": None, "tag_micro_r": None, "tag_micro_f": None},
default_config={
"model": DEFAULT_TAGGER_MODEL,
"overwrite": False,
"scorer": {"@scorers": "spacy.tagger_scorer.v1"},
"neg_prefix": "!",
"label_smoothing": 0.0,
},
default_score_weights={
"tag_acc": 1.0,
"pos_acc": 0.0,
"tag_micro_p": None,
"tag_micro_r": None,
"tag_micro_f": None,
},
)(make_tagger)
# nn_labeller
Language.factory(
"nn_labeller",
default_config={"labels": None, "target": "dep_tag_offset", "model": DEFAULT_MT_MODEL}
default_config={
"labels": None,
"target": "dep_tag_offset",
"model": DEFAULT_MT_MODEL,
},
)(make_nn_labeller)
# sentencizer
Language.factory(
"sentencizer",
assigns=["token.is_sent_start", "doc.sents"],
default_config={"punct_chars": None, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
default_config={
"punct_chars": None,
"overwrite": False,
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
},
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)(make_sentencizer)
# Set the flag to indicate that all factories have been registered
FACTORIES_REGISTERED = True

View File

@ -479,3 +479,4 @@ NAMES = [it[0] for it in sorted(IDS.items(), key=sort_nums)]
# (which is generating an enormous amount of C++ in Cython 0.24+)
# We keep the enum cdef, and just make sure the names are available to Python
locals().update(IDS)

View File

@ -87,7 +87,7 @@ def entity_linker():
objects_to_test = (
[nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
[nlp, vectors, custom_pipe, tagger, entity_linker],
["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)
@ -101,8 +101,9 @@ def write_obj_and_catch_warnings(obj):
return list(filter(lambda x: isinstance(x, ResourceWarning), warnings_list))
@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj):
@pytest.mark.parametrize("obj_factory", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj_factory):
obj = obj_factory()
warnings_list = write_obj_and_catch_warnings(obj)
assert len(warnings_list) == 0
@ -139,7 +140,7 @@ def test_save_and_load_knowledge_base():
class TestToDiskResourceWarningUnittest(TestCase):
def test_resource_warning(self):
scenarios = zip(*objects_to_test)
scenarios = zip(*[x() for x in objects_to_test]) # type: ignore
for scenario in scenarios:
with self.subTest(msg=scenario[1]):