Move component registrations under registrations.py

The functions can't be in Cython anymore, because Cython 3 stores
annotations as strings, so we can't read the types off the signatures.
To avoid having some of the functions in the file and some not, I've
moved the Python ones as well.
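
To illustrate the problem, here is a minimal plain-Python sketch (not code from this commit; make_component is a hypothetical factory). Under the from __future__ import annotations semantics that Cython 3 applies, a signature carries annotation strings rather than type objects:

    from __future__ import annotations

    import inspect

    def make_component(nlp, name: str, threshold: float):
        ...

    # The annotation is the string "float", not the type `float`, so there
    # is nothing usable on the signature to build a validation model from.
    param = inspect.signature(make_component).parameters["threshold"]
    print(repr(param.annotation))  # 'float'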

We'll need to re-import these functions into the files that previously
defined them, to maintain backwards compatibility. That might
require some import trickery to avoid circular imports.
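
One plausible shape for that trickery — a hypothetical sketch, assuming the factory functions end up importable from registrations at module level — is a lazy module-level __getattr__ (PEP 562) in each old module, so that importing the old module doesn't eagerly pull in registrations.py:

    # Hypothetical shim in, e.g., the module that used to define make_tagger:
    def __getattr__(name):
        if name == "make_tagger":
            # Imported lazily: registrations imports Tagger from this module,
            # so an eager top-level re-import here would be circular.
            from spacy.registrations import make_tagger

            return make_tagger
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")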
Matthew Honnibal 2025-05-20 18:27:12 +02:00
parent c3f9fab5e8
commit 6f1a65a8bd

spacy/registrations.py

@@ -4,7 +4,27 @@ This module centralizes registry decorations to prevent circular import issues
 with Cython annotation changes from __future__ import annotations. Functions
 remain in their original locations, but decoration is moved here.
 """
-from typing import Dict, Any, Callable, Iterable, List, Optional, Union
+from typing import Dict, Any, Callable, Iterable, List, Optional, Union, Tuple
+from thinc.api import Model
+from thinc.types import Floats2d, Ragged
+from .tokens.doc import Doc
+from .tokens.span import Span
+from .kb import KnowledgeBase, Candidate
+from .vocab import Vocab
+from .pipeline.textcat import TextCategorizer
+from .pipeline.tok2vec import Tok2Vec
+from .pipeline.spancat import SpanCategorizer, Suggester
+from .pipeline.textcat_multilabel import MultiLabel_TextCategorizer
+from .pipeline.entityruler import EntityRuler
+from .pipeline.span_finder import SpanFinder
+from .pipeline.ner import EntityRecognizer
+from .pipeline._parser_internals.transition_system import TransitionSystem
+from .pipeline.ner import EntityRecognizer
+from .pipeline.dep_parser import DependencyParser
+from .pipeline.dep_parser import DependencyParser
+from .pipeline.tagger import Tagger
+from .pipeline.multitask import MultitaskObjective
+from .pipeline.senter import SentenceRecognizer

 # Global flag to track if registry has been populated
 REGISTRY_POPULATED = False
@@ -112,24 +132,11 @@ def register_factories() -> None:
     """
     global FACTORIES_REGISTERED
-    from .language import Language
-    from .pipeline.sentencizer import Sentencizer

     if FACTORIES_REGISTERED:
         return

-    # TODO: We seem to still get cycle problems with these functions defined in Cython. We need
-    # a Python _factories module maybe?
-    def make_sentencizer(
-        nlp: Language,
-        name: str,
-        punct_chars: Optional[List[str]],
-        overwrite: bool,
-        scorer: Optional[Callable],
-    ):
-        return Sentencizer(
-            name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
-        )
+    from .language import Language
+    from .pipeline.sentencizer import Sentencizer

     # Import factory default configurations
     from .pipeline.entity_linker import DEFAULT_NEL_MODEL
@@ -150,31 +157,453 @@ def register_factories() -> None:
     from .pipeline.dep_parser import DEFAULT_PARSER_MODEL
     from .pipeline.tagger import DEFAULT_TAGGER_MODEL
     from .pipeline.multitask import DEFAULT_MT_MODEL
+    from .pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL

-    # Import all factory functions
-    from .pipeline.attributeruler import make_attribute_ruler
-    from .pipeline.entity_linker import make_entity_linker
-    from .pipeline.entityruler import make_entity_ruler
-    from .pipeline.lemmatizer import make_lemmatizer
-    from .pipeline.textcat import make_textcat, DEFAULT_SINGLE_TEXTCAT_MODEL
-    from .pipeline.functions import make_token_splitter, make_doc_cleaner
-    from .pipeline.tok2vec import make_tok2vec
-    from .pipeline.senter import make_senter
-    from .pipeline.morphologizer import make_morphologizer
-    from .pipeline.spancat import make_spancat, make_spancat_singlelabel
-    from .pipeline.span_ruler import (
-        make_entity_ruler as make_span_entity_ruler,
-        make_span_ruler,
-    )
-    from .pipeline.edit_tree_lemmatizer import make_edit_tree_lemmatizer
-    from .pipeline.textcat_multilabel import make_multilabel_textcat
-    from .pipeline.span_finder import make_span_finder
-    from .pipeline.ner import make_ner, make_beam_ner
-    from .pipeline.dep_parser import make_parser, make_beam_parser
-    from .pipeline.tagger import make_tagger
-    from .pipeline.multitask import make_nn_labeller
-    # from .pipeline.sentencizer import make_sentencizer
+    # We can't have function implementations for these factories in Cython, because
+    # we need to build a Pydantic model for them dynamically, reading their argument
+    # structure from the signature. In Cython 3, this doesn't work because the
+    # from __future__ import annotations semantics are used, which means the types
+    # are stored as strings.
+    def make_sentencizer(
+        nlp: Language,
+        name: str,
+        punct_chars: Optional[List[str]],
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        return Sentencizer(
+            name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
+        )
+
+    def make_attribute_ruler(
+        nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
+    ):
+        from .pipeline.attributeruler import AttributeRuler
+
+        return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
+
+    def make_entity_linker(
+        nlp: Language,
+        name: str,
+        model: Model,
+        *,
+        labels_discard: Iterable[str],
+        n_sents: int,
+        incl_prior: bool,
+        incl_context: bool,
+        entity_vector_length: int,
+        get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
+        get_candidates_batch: Callable[
+            [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
+        ],
+        generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
+        overwrite: bool,
+        scorer: Optional[Callable],
+        use_gold_ents: bool,
+        candidates_batch_size: int,
+        threshold: Optional[float] = None,
+    ):
+        from .pipeline.entity_linker import EntityLinker, EntityLinker_v1
+
+        if not model.attrs.get("include_span_maker", False):
+            # The only difference in arguments here is that use_gold_ents and threshold aren't available.
+            return EntityLinker_v1(
+                nlp.vocab,
+                model,
+                name,
+                labels_discard=labels_discard,
+                n_sents=n_sents,
+                incl_prior=incl_prior,
+                incl_context=incl_context,
+                entity_vector_length=entity_vector_length,
+                get_candidates=get_candidates,
+                overwrite=overwrite,
+                scorer=scorer,
+            )
+        return EntityLinker(
+            nlp.vocab,
+            model,
+            name,
+            labels_discard=labels_discard,
+            n_sents=n_sents,
+            incl_prior=incl_prior,
+            incl_context=incl_context,
+            entity_vector_length=entity_vector_length,
+            get_candidates=get_candidates,
+            get_candidates_batch=get_candidates_batch,
+            generate_empty_kb=generate_empty_kb,
+            overwrite=overwrite,
+            scorer=scorer,
+            use_gold_ents=use_gold_ents,
+            candidates_batch_size=candidates_batch_size,
+            threshold=threshold,
+        )
+
+    def make_lemmatizer(
+        nlp: Language,
+        model: Optional[Model],
+        name: str,
+        mode: str,
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.lemmatizer import Lemmatizer
+
+        return Lemmatizer(
+            nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
+        )
+
+    def make_textcat(
+        nlp: Language,
+        name: str,
+        model: Model[List[Doc], List[Floats2d]],
+        threshold: float,
+        scorer: Optional[Callable],
+    ) -> TextCategorizer:
+        return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
+
+    def make_token_splitter(
+        nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
+    ):
+        from .pipeline.functions import TokenSplitter
+
+        return TokenSplitter(min_length=min_length, split_length=split_length)
+
+    def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
+        from .pipeline.functions import DocCleaner
+
+        return DocCleaner(attrs, silent=silent)
+
+    def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec:
+        return Tok2Vec(nlp.vocab, model, name)
+
+    def make_spancat(
+        nlp: Language,
+        name: str,
+        suggester: Suggester,
+        model: Model[Tuple[List[Doc], Ragged], Floats2d],
+        spans_key: str,
+        scorer: Optional[Callable],
+        threshold: float,
+        max_positive: Optional[int],
+    ) -> SpanCategorizer:
+        return SpanCategorizer(
+            nlp.vocab,
+            model=model,
+            suggester=suggester,
+            name=name,
+            spans_key=spans_key,
+            negative_weight=None,
+            allow_overlap=True,
+            max_positive=max_positive,
+            threshold=threshold,
+            scorer=scorer,
+            add_negative_label=False,
+        )
+
+    def make_spancat_singlelabel(
+        nlp: Language,
+        name: str,
+        suggester: Suggester,
+        model: Model[Tuple[List[Doc], Ragged], Floats2d],
+        spans_key: str,
+        negative_weight: float,
+        allow_overlap: bool,
+        scorer: Optional[Callable],
+    ) -> "SpanCategorizer":
+        from .pipeline.spancat import SpanCategorizer
+
+        return SpanCategorizer(
+            nlp.vocab,
+            model=model,
+            suggester=suggester,
+            name=name,
+            spans_key=spans_key,
+            negative_weight=negative_weight,
+            allow_overlap=allow_overlap,
+            max_positive=1,
+            add_negative_label=True,
+            threshold=None,
+            scorer=scorer,
+        )
+
+    def make_future_entity_ruler(
+        nlp: Language,
+        name: str,
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite_ents: bool,
+        scorer: Optional[Callable],
+        ent_id_sep: str,
+    ):
+        from .pipeline.span_ruler import SpanRuler, prioritize_new_ents_filter, prioritize_existing_ents_filter
+
+        if overwrite_ents:
+            ents_filter = prioritize_new_ents_filter
+        else:
+            ents_filter = prioritize_existing_ents_filter
+        return SpanRuler(
+            nlp,
+            name,
+            spans_key=None,
+            spans_filter=None,
+            annotate_ents=True,
+            ents_filter=ents_filter,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite=False,
+            scorer=scorer,
+        )
+
+    def make_entity_ruler(
+        nlp: Language,
+        name: str,
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite_ents: bool,
+        ent_id_sep: str,
+        scorer: Optional[Callable],
+    ):
+        return EntityRuler(
+            nlp,
+            name,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite_ents=overwrite_ents,
+            ent_id_sep=ent_id_sep,
+            scorer=scorer,
+        )
+
+    def make_span_ruler(
+        nlp: Language,
+        name: str,
+        spans_key: Optional[str],
+        spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
+        annotate_ents: bool,
+        ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.span_ruler import SpanRuler
+
+        return SpanRuler(
+            nlp,
+            name,
+            spans_key=spans_key,
+            spans_filter=spans_filter,
+            annotate_ents=annotate_ents,
+            ents_filter=ents_filter,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite=overwrite,
+            scorer=scorer,
+        )
+
+    def make_edit_tree_lemmatizer(
+        nlp: Language,
+        name: str,
+        model: Model,
+        backoff: Optional[str],
+        min_tree_freq: int,
+        overwrite: bool,
+        top_k: int,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.edit_tree_lemmatizer import EditTreeLemmatizer
+
+        return EditTreeLemmatizer(
+            nlp.vocab,
+            model,
+            name,
+            backoff=backoff,
+            min_tree_freq=min_tree_freq,
+            overwrite=overwrite,
+            top_k=top_k,
+            scorer=scorer,
+        )
+
+    def make_multilabel_textcat(
+        nlp: Language,
+        name: str,
+        model: Model[List[Doc], List[Floats2d]],
+        threshold: float,
+        scorer: Optional[Callable],
+    ) -> MultiLabel_TextCategorizer:
+        return MultiLabel_TextCategorizer(
+            nlp.vocab, model, name, threshold=threshold, scorer=scorer
+        )
+
+    def make_span_finder(
+        nlp: Language,
+        name: str,
+        model: Model[Iterable[Doc], Floats2d],
+        spans_key: str,
+        threshold: float,
+        max_length: Optional[int],
+        min_length: Optional[int],
+        scorer: Optional[Callable],
+    ) -> SpanFinder:
+        return SpanFinder(
+            nlp,
+            model=model,
+            threshold=threshold,
+            name=name,
+            scorer=scorer,
+            max_length=max_length,
+            min_length=min_length,
+            spans_key=spans_key,
+        )
+
+    def make_ner(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        incorrect_spans_key: Optional[str],
+        scorer: Optional[Callable],
+    ):
+        return EntityRecognizer(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            incorrect_spans_key=incorrect_spans_key,
+            scorer=scorer,
+        )
+
+    def make_beam_ner(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        beam_width: int,
+        beam_density: float,
+        beam_update_prob: float,
+        incorrect_spans_key: Optional[str],
+        scorer: Optional[Callable],
+    ):
+        return EntityRecognizer(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            beam_width=beam_width,
+            beam_density=beam_density,
+            beam_update_prob=beam_update_prob,
+            incorrect_spans_key=incorrect_spans_key,
+            scorer=scorer,
+        )
+
+    def make_parser(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        learn_tokens: bool,
+        min_action_freq: int,
+        scorer: Optional[Callable],
+    ):
+        return DependencyParser(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            learn_tokens=learn_tokens,
+            min_action_freq=min_action_freq,
+            scorer=scorer,
+        )
+
+    def make_beam_parser(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        learn_tokens: bool,
+        min_action_freq: int,
+        beam_width: int,
+        beam_density: float,
+        beam_update_prob: float,
+        scorer: Optional[Callable],
+    ):
+        return DependencyParser(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            learn_tokens=learn_tokens,
+            min_action_freq=min_action_freq,
+            beam_width=beam_width,
+            beam_density=beam_density,
+            beam_update_prob=beam_update_prob,
+            scorer=scorer,
+        )
+
+    def make_tagger(
+        nlp: Language,
+        name: str,
+        model: Model,
+        overwrite: bool,
+        scorer: Optional[Callable],
+        neg_prefix: str,
+        label_smoothing: float,
+    ):
+        return Tagger(
+            nlp.vocab,
+            model,
+            name=name,
+            overwrite=overwrite,
+            scorer=scorer,
+            neg_prefix=neg_prefix,
+            label_smoothing=label_smoothing,
+        )
+
+    def make_nn_labeller(
+        nlp: Language,
+        name: str,
+        model: Model,
+        labels: Optional[dict],
+        target: str
+    ):
+        return MultitaskObjective(nlp.vocab, model, name, target=target)
+
+    def make_morphologizer(
+        nlp: Language,
+        model: Model,
+        name: str,
+        overwrite: bool,
+        extend: bool,
+        label_smoothing: float,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.morphologizer import Morphologizer
+
+        return Morphologizer(
+            nlp.vocab, model, name,
+            overwrite=overwrite,
+            extend=extend,
+            label_smoothing=label_smoothing,
+            scorer=scorer
+        )
+
+    def make_senter(
+        nlp: Language,
+        name: str,
+        model: Model,
+        overwrite: bool,
+        scorer: Optional[Callable]
+    ):
+        return SentenceRecognizer(
+            nlp.vocab, model, name,
+            overwrite=overwrite,
+            scorer=scorer
+        )

     # Register factories using the same pattern as Language.factory decorator
     # We use Language.factory()() pattern which exactly mimics the decorator
@@ -370,7 +799,7 @@ def register_factories() -> None:
             "ents_r": 0.0,
             "ents_per_type": None,
         },
-    )(make_span_entity_ruler)
+    )(make_future_entity_ruler)

     # span_ruler
     Language.factory(
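
The Language.factory()() call pattern referenced in the comments above works because Language.factory, when called without a function, returns the same registering wrapper it applies as a decorator. A minimal sketch with a hypothetical no-op component (not part of this commit):

    from spacy.language import Language

    def make_noop(nlp: Language, name: str):
        # A do-nothing pipeline component, purely for illustration.
        def noop(doc):
            return doc

        return noop

    # Equivalent to decorating make_noop with @Language.factory("noop"):
    Language.factory("noop", default_config={})(make_noop)

Once registered this way, nlp.add_pipe("noop") resolves through exactly the same factory machinery as a component registered with the decorator form.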