diff --git a/spacy/registrations.py b/spacy/registrations.py
index 1dc852b4f..bf0a1ec4e 100644
--- a/spacy/registrations.py
+++ b/spacy/registrations.py
@@ -4,7 +4,25 @@ This module centralizes registry decorations to prevent circular import issues
 with Cython annotation changes from __future__ import annotations. Functions
 remain in their original locations, but decoration is moved here.
 """
-from typing import Dict, Any, Callable, Iterable, List, Optional, Union
+from typing import Dict, Any, Callable, Iterable, List, Optional, Union, Tuple
+from thinc.api import Model
+from thinc.types import Floats2d, Ragged
+from .tokens.doc import Doc
+from .tokens.span import Span
+from .kb import KnowledgeBase, Candidate
+from .vocab import Vocab
+from .pipeline.textcat import TextCategorizer
+from .pipeline.tok2vec import Tok2Vec
+from .pipeline.spancat import SpanCategorizer, Suggester
+from .pipeline.textcat_multilabel import MultiLabel_TextCategorizer
+from .pipeline.entityruler import EntityRuler
+from .pipeline.span_finder import SpanFinder
+from .pipeline.ner import EntityRecognizer
+from .pipeline._parser_internals.transition_system import TransitionSystem
+from .pipeline.dep_parser import DependencyParser
+from .pipeline.tagger import Tagger
+from .pipeline.multitask import MultitaskObjective
+from .pipeline.senter import SentenceRecognizer
 
 # Global flag to track if registry has been populated
 REGISTRY_POPULATED = False
@@ -112,24 +130,11 @@ def register_factories() -> None:
     """
     global FACTORIES_REGISTERED
 
-    from .language import Language
-    from .pipeline.sentencizer import Sentencizer
-
    if FACTORIES_REGISTERED:
        return
 
-    # TODO: We seem to still get cycle problems with these functions defined in Cython. We need
-    # a Python _factories module maybe?
-    def make_sentencizer(
-        nlp: Language,
-        name: str,
-        punct_chars: Optional[List[str]],
-        overwrite: bool,
-        scorer: Optional[Callable],
-    ):
-        return Sentencizer(
-            name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
-        )
+    from .language import Language
+    from .pipeline.sentencizer import Sentencizer
 
     # Import factory default configurations
     from .pipeline.entity_linker import DEFAULT_NEL_MODEL
@@ -150,31 +155,463 @@ def register_factories() -> None:
     from .pipeline.dep_parser import DEFAULT_PARSER_MODEL
     from .pipeline.tagger import DEFAULT_TAGGER_MODEL
     from .pipeline.multitask import DEFAULT_MT_MODEL
+    from .pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL
 
-    # Import all factory functions
-    from .pipeline.attributeruler import make_attribute_ruler
-    from .pipeline.entity_linker import make_entity_linker
-    from .pipeline.entityruler import make_entity_ruler
-    from .pipeline.lemmatizer import make_lemmatizer
-    from .pipeline.textcat import make_textcat, DEFAULT_SINGLE_TEXTCAT_MODEL
-    from .pipeline.functions import make_token_splitter, make_doc_cleaner
-    from .pipeline.tok2vec import make_tok2vec
-    from .pipeline.senter import make_senter
-    from .pipeline.morphologizer import make_morphologizer
-    from .pipeline.spancat import make_spancat, make_spancat_singlelabel
-    from .pipeline.span_ruler import (
-        make_entity_ruler as make_span_entity_ruler,
-        make_span_ruler,
-    )
-    from .pipeline.edit_tree_lemmatizer import make_edit_tree_lemmatizer
-    from .pipeline.textcat_multilabel import make_multilabel_textcat
-    from .pipeline.span_finder import make_span_finder
-    from .pipeline.ner import make_ner, make_beam_ner
-    from .pipeline.dep_parser import make_parser, make_beam_parser
-    from .pipeline.tagger import make_tagger
-    from .pipeline.multitask import make_nn_labeller
+    # We can't have function implementations for these factories in Cython, because
+    # we need to build a Pydantic model for them dynamically, reading their argument
+    # structure from the signature. In Cython 3, this doesn't work because the
+    # from __future__ import annotations semantics are used, which means the types
+    # are stored as strings.
+    def make_sentencizer(
+        nlp: Language,
+        name: str,
+        punct_chars: Optional[List[str]],
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        return Sentencizer(
+            name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
+        )
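+    # Illustrative note (an assumption for exposition, not executed code): had
+    # make_sentencizer stayed in a Cython 3 module, inspect.signature() would
+    # report its annotations as strings such as "Optional[List[str]]" instead
+    # of typing objects, so the dynamic Pydantic config model could not be built.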
 
-    # from .pipeline.sentencizer import make_sentencizer
+    def make_attribute_ruler(
+        nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
+    ):
+        from .pipeline.attributeruler import AttributeRuler
+        return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
+
+    def make_entity_linker(
+        nlp: Language,
+        name: str,
+        model: Model,
+        *,
+        labels_discard: Iterable[str],
+        n_sents: int,
+        incl_prior: bool,
+        incl_context: bool,
+        entity_vector_length: int,
+        get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
+        get_candidates_batch: Callable[
+            [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
+        ],
+        generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
+        overwrite: bool,
+        scorer: Optional[Callable],
+        use_gold_ents: bool,
+        candidates_batch_size: int,
+        threshold: Optional[float] = None,
+    ):
+        from .pipeline.entity_linker import EntityLinker, EntityLinker_v1
+
+        if not model.attrs.get("include_span_maker", False):
+            # The only difference in arguments here is that use_gold_ents and threshold aren't available.
+            return EntityLinker_v1(
+                nlp.vocab,
+                model,
+                name,
+                labels_discard=labels_discard,
+                n_sents=n_sents,
+                incl_prior=incl_prior,
+                incl_context=incl_context,
+                entity_vector_length=entity_vector_length,
+                get_candidates=get_candidates,
+                overwrite=overwrite,
+                scorer=scorer,
+            )
+        return EntityLinker(
+            nlp.vocab,
+            model,
+            name,
+            labels_discard=labels_discard,
+            n_sents=n_sents,
+            incl_prior=incl_prior,
+            incl_context=incl_context,
+            entity_vector_length=entity_vector_length,
+            get_candidates=get_candidates,
+            get_candidates_batch=get_candidates_batch,
+            generate_empty_kb=generate_empty_kb,
+            overwrite=overwrite,
+            scorer=scorer,
+            use_gold_ents=use_gold_ents,
+            candidates_batch_size=candidates_batch_size,
+            threshold=threshold,
+        )
+
+    def make_lemmatizer(
+        nlp: Language,
+        model: Optional[Model],
+        name: str,
+        mode: str,
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.lemmatizer import Lemmatizer
+        return Lemmatizer(
+            nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
+        )
+
+    def make_textcat(
+        nlp: Language,
+        name: str,
+        model: Model[List[Doc], List[Floats2d]],
+        threshold: float,
+        scorer: Optional[Callable],
+    ) -> TextCategorizer:
+        return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
+
+    def make_token_splitter(
+        nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
+    ):
+        from .pipeline.functions import TokenSplitter
+        return TokenSplitter(min_length=min_length, split_length=split_length)
+
+    def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
+        from .pipeline.functions import DocCleaner
+        return DocCleaner(attrs, silent=silent)
+
+    def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec:
+        return Tok2Vec(nlp.vocab, model, name)
+
+    def make_spancat(
+        nlp: Language,
+        name: str,
+        suggester: Suggester,
+        model: Model[Tuple[List[Doc], Ragged], Floats2d],
+        spans_key: str,
+        scorer: Optional[Callable],
+        threshold: float,
+        max_positive: Optional[int],
+    ) -> SpanCategorizer:
+        return SpanCategorizer(
+            nlp.vocab,
+            model=model,
+            suggester=suggester,
+            name=name,
+            spans_key=spans_key,
+            negative_weight=None,
+            allow_overlap=True,
+            max_positive=max_positive,
+            threshold=threshold,
+            scorer=scorer,
+            add_negative_label=False,
+        )
+
+    def make_spancat_singlelabel(
+        nlp: Language,
+        name: str,
+        suggester: Suggester,
+        model: Model[Tuple[List[Doc], Ragged], Floats2d],
+        spans_key: str,
+        negative_weight: float,
+        allow_overlap: bool,
+        scorer: Optional[Callable],
+    ) -> "SpanCategorizer":
+        from .pipeline.spancat import SpanCategorizer
+        return SpanCategorizer(
+            nlp.vocab,
+            model=model,
+            suggester=suggester,
+            name=name,
+            spans_key=spans_key,
+            negative_weight=negative_weight,
+            allow_overlap=allow_overlap,
+            max_positive=1,
+            add_negative_label=True,
+            threshold=None,
+            scorer=scorer,
+        )
+
+    def make_future_entity_ruler(
+        nlp: Language,
+        name: str,
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite_ents: bool,
+        scorer: Optional[Callable],
+        ent_id_sep: str,
+    ):
+        from .pipeline.span_ruler import SpanRuler, prioritize_new_ents_filter, prioritize_existing_ents_filter
+        if overwrite_ents:
+            ents_filter = prioritize_new_ents_filter
+        else:
+            ents_filter = prioritize_existing_ents_filter
+        return SpanRuler(
+            nlp,
+            name,
+            spans_key=None,
+            spans_filter=None,
+            annotate_ents=True,
+            ents_filter=ents_filter,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite=False,
+            scorer=scorer,
+        )
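+    # Note: make_future_entity_ruler maps the legacy overwrite_ents flag onto
+    # the SpanRuler entity filters chosen above; ent_id_sep is accepted only
+    # for backwards-compatible config signatures and is not used by SpanRuler.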
+
+    def make_entity_ruler(
+        nlp: Language,
+        name: str,
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite_ents: bool,
+        ent_id_sep: str,
+        scorer: Optional[Callable],
+    ):
+        return EntityRuler(
+            nlp,
+            name,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite_ents=overwrite_ents,
+            ent_id_sep=ent_id_sep,
+            scorer=scorer,
+        )
+
+    def make_span_ruler(
+        nlp: Language,
+        name: str,
+        spans_key: Optional[str],
+        spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
+        annotate_ents: bool,
+        ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
+        phrase_matcher_attr: Optional[Union[int, str]],
+        matcher_fuzzy_compare: Callable,
+        validate: bool,
+        overwrite: bool,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.span_ruler import SpanRuler
+        return SpanRuler(
+            nlp,
+            name,
+            spans_key=spans_key,
+            spans_filter=spans_filter,
+            annotate_ents=annotate_ents,
+            ents_filter=ents_filter,
+            phrase_matcher_attr=phrase_matcher_attr,
+            matcher_fuzzy_compare=matcher_fuzzy_compare,
+            validate=validate,
+            overwrite=overwrite,
+            scorer=scorer,
+        )
+
+    def make_edit_tree_lemmatizer(
+        nlp: Language,
+        name: str,
+        model: Model,
+        backoff: Optional[str],
+        min_tree_freq: int,
+        overwrite: bool,
+        top_k: int,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.edit_tree_lemmatizer import EditTreeLemmatizer
+        return EditTreeLemmatizer(
+            nlp.vocab,
+            model,
+            name,
+            backoff=backoff,
+            min_tree_freq=min_tree_freq,
+            overwrite=overwrite,
+            top_k=top_k,
+            scorer=scorer,
+        )
+
+    def make_multilabel_textcat(
+        nlp: Language,
+        name: str,
+        model: Model[List[Doc], List[Floats2d]],
+        threshold: float,
+        scorer: Optional[Callable],
+    ) -> MultiLabel_TextCategorizer:
+        return MultiLabel_TextCategorizer(
+            nlp.vocab, model, name, threshold=threshold, scorer=scorer
+        )
+
+    def make_span_finder(
+        nlp: Language,
+        name: str,
+        model: Model[Iterable[Doc], Floats2d],
+        spans_key: str,
+        threshold: float,
+        max_length: Optional[int],
+        min_length: Optional[int],
+        scorer: Optional[Callable],
+    ) -> SpanFinder:
+        return SpanFinder(
+            nlp,
+            model=model,
+            threshold=threshold,
+            name=name,
+            scorer=scorer,
+            max_length=max_length,
+            min_length=min_length,
+            spans_key=spans_key,
+        )
+
+    def make_ner(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        incorrect_spans_key: Optional[str],
+        scorer: Optional[Callable],
+    ):
+        return EntityRecognizer(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            incorrect_spans_key=incorrect_spans_key,
+            scorer=scorer,
+        )
+
+    def make_beam_ner(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        beam_width: int,
+        beam_density: float,
+        beam_update_prob: float,
+        incorrect_spans_key: Optional[str],
+        scorer: Optional[Callable],
+    ):
+        return EntityRecognizer(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            beam_width=beam_width,
+            beam_density=beam_density,
+            beam_update_prob=beam_update_prob,
+            incorrect_spans_key=incorrect_spans_key,
+            scorer=scorer,
+        )
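+    # Note: make_ner and make_beam_ner build the same EntityRecognizer class;
+    # greedy decoding is effectively the beam_width=1 special case, so the two
+    # factories differ only in the beam-search settings they pass through.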
+
+    def make_parser(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        learn_tokens: bool,
+        min_action_freq: int,
+        scorer: Optional[Callable],
+    ):
+        return DependencyParser(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            learn_tokens=learn_tokens,
+            min_action_freq=min_action_freq,
+            scorer=scorer,
+        )
+
+    def make_beam_parser(
+        nlp: Language,
+        name: str,
+        model: Model,
+        moves: Optional[TransitionSystem],
+        update_with_oracle_cut_size: int,
+        learn_tokens: bool,
+        min_action_freq: int,
+        beam_width: int,
+        beam_density: float,
+        beam_update_prob: float,
+        scorer: Optional[Callable],
+    ):
+        return DependencyParser(
+            nlp.vocab,
+            model,
+            name=name,
+            moves=moves,
+            update_with_oracle_cut_size=update_with_oracle_cut_size,
+            learn_tokens=learn_tokens,
+            min_action_freq=min_action_freq,
+            beam_width=beam_width,
+            beam_density=beam_density,
+            beam_update_prob=beam_update_prob,
+            scorer=scorer,
+        )
+
+    def make_tagger(
+        nlp: Language,
+        name: str,
+        model: Model,
+        overwrite: bool,
+        scorer: Optional[Callable],
+        neg_prefix: str,
+        label_smoothing: float,
+    ):
+        return Tagger(
+            nlp.vocab,
+            model,
+            name=name,
+            overwrite=overwrite,
+            scorer=scorer,
+            neg_prefix=neg_prefix,
+            label_smoothing=label_smoothing,
+        )
+
+    def make_nn_labeller(
+        nlp: Language,
+        name: str,
+        model: Model,
+        labels: Optional[dict],
+        target: str
+    ):
+        return MultitaskObjective(nlp.vocab, model, name, target=target)
+
+    def make_morphologizer(
+        nlp: Language,
+        model: Model,
+        name: str,
+        overwrite: bool,
+        extend: bool,
+        label_smoothing: float,
+        scorer: Optional[Callable],
+    ):
+        from .pipeline.morphologizer import Morphologizer
+        return Morphologizer(
+            nlp.vocab, model, name,
+            overwrite=overwrite,
+            extend=extend,
+            label_smoothing=label_smoothing,
+            scorer=scorer
+        )
+
+    def make_senter(
+        nlp: Language,
+        name: str,
+        model: Model,
+        overwrite: bool,
+        scorer: Optional[Callable]
+    ):
+        return SentenceRecognizer(
+            nlp.vocab, model, name,
+            overwrite=overwrite,
+            scorer=scorer
+        )
 
     # Register factories using the same pattern as Language.factory decorator
     # We use Language.factory()() pattern which exactly mimics the decorator
@@ -370,7 +807,7 @@ def register_factories() -> None:
             "ents_r": 0.0,
             "ents_per_type": None,
         },
-    )(make_span_entity_ruler)
+    )(make_future_entity_ruler)
 
     # span_ruler
     Language.factory(