Make factories top-level functions in registrations.py

This commit is contained in:
Matthew Honnibal 2025-05-21 14:03:11 +02:00
parent 5c331884c3
commit d8388aa591
2 changed files with 443 additions and 438 deletions

View File

@ -19,8 +19,6 @@ from .pipeline.entityruler import EntityRuler
from .pipeline.span_finder import SpanFinder
from .pipeline.ner import EntityRecognizer
from .pipeline._parser_internals.transition_system import TransitionSystem
from .pipeline.ner import EntityRecognizer
from .pipeline.dep_parser import DependencyParser
from .pipeline.dep_parser import DependencyParser
from .pipeline.tagger import Tagger
from .pipeline.multitask import MultitaskObjective
@ -169,442 +167,6 @@ def register_factories() -> None:
if FACTORIES_REGISTERED:
return
# We can't have function implementations for these factories in Cython, because
# we need to build a Pydantic model for them dynamically, reading their argument
# structure from the signature. In Cython 3, this doesn't work because the
# from __future__ import annotations semantics are used, which means the types
# are stored as strings.
def make_sentencizer(
nlp: Language,
name: str,
punct_chars: Optional[List[str]],
overwrite: bool,
scorer: Optional[Callable],
):
return Sentencizer(
name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
)
def make_attribute_ruler(
nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
):
return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
def make_entity_linker(
nlp: Language,
name: str,
model: Model,
*,
labels_discard: Iterable[str],
n_sents: int,
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None,
):
if not model.attrs.get("include_span_maker", False):
# The only difference in arguments here is that use_gold_ents and threshold aren't available.
return EntityLinker_v1(
nlp.vocab,
model,
name,
labels_discard=labels_discard,
n_sents=n_sents,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
overwrite=overwrite,
scorer=scorer,
)
return EntityLinker(
nlp.vocab,
model,
name,
labels_discard=labels_discard,
n_sents=n_sents,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size,
threshold=threshold,
)
def make_lemmatizer(
nlp: Language,
model: Optional[Model],
name: str,
mode: str,
overwrite: bool,
scorer: Optional[Callable],
):
return Lemmatizer(
nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
)
def make_textcat(
nlp: Language,
name: str,
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
) -> TextCategorizer:
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
def make_token_splitter(
nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
):
return TokenSplitter(min_length=min_length, split_length=split_length)
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
return DocCleaner(attrs, silent=silent)
def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec:
return Tok2Vec(nlp.vocab, model, name)
def make_spancat(
nlp: Language,
name: str,
suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
scorer: Optional[Callable],
threshold: float,
max_positive: Optional[int],
) -> SpanCategorizer:
return SpanCategorizer(
nlp.vocab,
model=model,
suggester=suggester,
name=name,
spans_key=spans_key,
negative_weight=None,
allow_overlap=True,
max_positive=max_positive,
threshold=threshold,
scorer=scorer,
add_negative_label=False,
)
def make_spancat_singlelabel(
nlp: Language,
name: str,
suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
negative_weight: float,
allow_overlap: bool,
scorer: Optional[Callable],
) -> SpanCategorizer:
return SpanCategorizer(
nlp.vocab,
model=model,
suggester=suggester,
name=name,
spans_key=spans_key,
negative_weight=negative_weight,
allow_overlap=allow_overlap,
max_positive=1,
add_negative_label=True,
threshold=None,
scorer=scorer,
)
def make_future_entity_ruler(
nlp: Language,
name: str,
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite_ents: bool,
scorer: Optional[Callable],
ent_id_sep: str,
):
if overwrite_ents:
ents_filter = prioritize_new_ents_filter
else:
ents_filter = prioritize_existing_ents_filter
return SpanRuler(
nlp,
name,
spans_key=None,
spans_filter=None,
annotate_ents=True,
ents_filter=ents_filter,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite=False,
scorer=scorer,
)
def make_entity_ruler(
nlp: Language,
name: str,
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite_ents: bool,
ent_id_sep: str,
scorer: Optional[Callable],
):
return EntityRuler(
nlp,
name,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite_ents=overwrite_ents,
ent_id_sep=ent_id_sep,
scorer=scorer,
)
def make_span_ruler(
nlp: Language,
name: str,
spans_key: Optional[str],
spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
annotate_ents: bool,
ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite: bool,
scorer: Optional[Callable],
):
return SpanRuler(
nlp,
name,
spans_key=spans_key,
spans_filter=spans_filter,
annotate_ents=annotate_ents,
ents_filter=ents_filter,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite=overwrite,
scorer=scorer,
)
def make_edit_tree_lemmatizer(
nlp: Language,
name: str,
model: Model,
backoff: Optional[str],
min_tree_freq: int,
overwrite: bool,
top_k: int,
scorer: Optional[Callable],
):
return EditTreeLemmatizer(
nlp.vocab,
model,
name,
backoff=backoff,
min_tree_freq=min_tree_freq,
overwrite=overwrite,
top_k=top_k,
scorer=scorer,
)
def make_multilabel_textcat(
nlp: Language,
name: str,
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
) -> MultiLabel_TextCategorizer:
return MultiLabel_TextCategorizer(
nlp.vocab, model, name, threshold=threshold, scorer=scorer
)
def make_span_finder(
nlp: Language,
name: str,
model: Model[Iterable[Doc], Floats2d],
spans_key: str,
threshold: float,
max_length: Optional[int],
min_length: Optional[int],
scorer: Optional[Callable],
) -> SpanFinder:
return SpanFinder(
nlp,
model=model,
threshold=threshold,
name=name,
scorer=scorer,
max_length=max_length,
min_length=min_length,
spans_key=spans_key,
)
def make_ner(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
incorrect_spans_key: Optional[str],
scorer: Optional[Callable],
):
return EntityRecognizer(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
incorrect_spans_key=incorrect_spans_key,
scorer=scorer,
)
def make_beam_ner(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
incorrect_spans_key: Optional[str],
scorer: Optional[Callable],
):
return EntityRecognizer(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
incorrect_spans_key=incorrect_spans_key,
scorer=scorer,
)
def make_parser(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
learn_tokens: bool,
min_action_freq: int,
scorer: Optional[Callable],
):
return DependencyParser(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
scorer=scorer,
)
def make_beam_parser(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
learn_tokens: bool,
min_action_freq: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
scorer: Optional[Callable],
):
return DependencyParser(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
scorer=scorer,
)
def make_tagger(
nlp: Language,
name: str,
model: Model,
overwrite: bool,
scorer: Optional[Callable],
neg_prefix: str,
label_smoothing: float,
):
return Tagger(
nlp.vocab,
model,
name=name,
overwrite=overwrite,
scorer=scorer,
neg_prefix=neg_prefix,
label_smoothing=label_smoothing,
)
def make_nn_labeller(
nlp: Language,
name: str,
model: Model,
labels: Optional[dict],
target: str
):
return MultitaskObjective(nlp.vocab, model, name, target=target)
def make_morphologizer(
nlp: Language,
model: Model,
name: str,
overwrite: bool,
extend: bool,
label_smoothing: float,
scorer: Optional[Callable],
):
return Morphologizer(
nlp.vocab, model, name,
overwrite=overwrite,
extend=extend,
label_smoothing=label_smoothing,
scorer=scorer
)
def make_senter(
nlp: Language,
name: str,
model: Model,
overwrite: bool,
scorer: Optional[Callable]
):
return SentenceRecognizer(
nlp.vocab, model, name,
overwrite=overwrite,
scorer=scorer
)
# Register factories using the same pattern as Language.factory decorator
# We use Language.factory()() pattern which exactly mimics the decorator
@ -1017,3 +579,442 @@ def register_factories() -> None:
# Set the flag to indicate that all factories have been registered
FACTORIES_REGISTERED = True
# We can't have function implementations for these factories in Cython, because
# we need to build a Pydantic model for them dynamically, reading their argument
# structure from the signature. In Cython 3, this doesn't work because the
# from __future__ import annotations semantics are used, which means the types
# are stored as strings.
def make_sentencizer(
nlp: Language,
name: str,
punct_chars: Optional[List[str]],
overwrite: bool,
scorer: Optional[Callable],
):
return Sentencizer(
name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer
)
def make_attribute_ruler(
nlp: Language, name: str, validate: bool, scorer: Optional[Callable]
):
return AttributeRuler(nlp.vocab, name, validate=validate, scorer=scorer)
def make_entity_linker(
nlp: Language,
name: str,
model: Model,
*,
labels_discard: Iterable[str],
n_sents: int,
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None,
):
if not model.attrs.get("include_span_maker", False):
# The only difference in arguments here is that use_gold_ents and threshold aren't available.
return EntityLinker_v1(
nlp.vocab,
model,
name,
labels_discard=labels_discard,
n_sents=n_sents,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
overwrite=overwrite,
scorer=scorer,
)
return EntityLinker(
nlp.vocab,
model,
name,
labels_discard=labels_discard,
n_sents=n_sents,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size,
threshold=threshold,
)
def make_lemmatizer(
nlp: Language,
model: Optional[Model],
name: str,
mode: str,
overwrite: bool,
scorer: Optional[Callable],
):
return Lemmatizer(
nlp.vocab, model, name, mode=mode, overwrite=overwrite, scorer=scorer
)
def make_textcat(
nlp: Language,
name: str,
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
) -> TextCategorizer:
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
def make_token_splitter(
nlp: Language, name: str, *, min_length: int = 0, split_length: int = 0
):
return TokenSplitter(min_length=min_length, split_length=split_length)
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
return DocCleaner(attrs, silent=silent)
def make_tok2vec(nlp: Language, name: str, model: Model) -> Tok2Vec:
return Tok2Vec(nlp.vocab, model, name)
def make_spancat(
nlp: Language,
name: str,
suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
scorer: Optional[Callable],
threshold: float,
max_positive: Optional[int],
) -> SpanCategorizer:
return SpanCategorizer(
nlp.vocab,
model=model,
suggester=suggester,
name=name,
spans_key=spans_key,
negative_weight=None,
allow_overlap=True,
max_positive=max_positive,
threshold=threshold,
scorer=scorer,
add_negative_label=False,
)
def make_spancat_singlelabel(
nlp: Language,
name: str,
suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
negative_weight: float,
allow_overlap: bool,
scorer: Optional[Callable],
) -> SpanCategorizer:
return SpanCategorizer(
nlp.vocab,
model=model,
suggester=suggester,
name=name,
spans_key=spans_key,
negative_weight=negative_weight,
allow_overlap=allow_overlap,
max_positive=1,
add_negative_label=True,
threshold=None,
scorer=scorer,
)
def make_future_entity_ruler(
nlp: Language,
name: str,
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite_ents: bool,
scorer: Optional[Callable],
ent_id_sep: str,
):
if overwrite_ents:
ents_filter = prioritize_new_ents_filter
else:
ents_filter = prioritize_existing_ents_filter
return SpanRuler(
nlp,
name,
spans_key=None,
spans_filter=None,
annotate_ents=True,
ents_filter=ents_filter,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite=False,
scorer=scorer,
)
def make_entity_ruler(
nlp: Language,
name: str,
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite_ents: bool,
ent_id_sep: str,
scorer: Optional[Callable],
):
return EntityRuler(
nlp,
name,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite_ents=overwrite_ents,
ent_id_sep=ent_id_sep,
scorer=scorer,
)
def make_span_ruler(
nlp: Language,
name: str,
spans_key: Optional[str],
spans_filter: Optional[Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]]],
annotate_ents: bool,
ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
phrase_matcher_attr: Optional[Union[int, str]],
matcher_fuzzy_compare: Callable,
validate: bool,
overwrite: bool,
scorer: Optional[Callable],
):
return SpanRuler(
nlp,
name,
spans_key=spans_key,
spans_filter=spans_filter,
annotate_ents=annotate_ents,
ents_filter=ents_filter,
phrase_matcher_attr=phrase_matcher_attr,
matcher_fuzzy_compare=matcher_fuzzy_compare,
validate=validate,
overwrite=overwrite,
scorer=scorer,
)
def make_edit_tree_lemmatizer(
nlp: Language,
name: str,
model: Model,
backoff: Optional[str],
min_tree_freq: int,
overwrite: bool,
top_k: int,
scorer: Optional[Callable],
):
return EditTreeLemmatizer(
nlp.vocab,
model,
name,
backoff=backoff,
min_tree_freq=min_tree_freq,
overwrite=overwrite,
top_k=top_k,
scorer=scorer,
)
def make_multilabel_textcat(
nlp: Language,
name: str,
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
) -> MultiLabel_TextCategorizer:
return MultiLabel_TextCategorizer(
nlp.vocab, model, name, threshold=threshold, scorer=scorer
)
def make_span_finder(
nlp: Language,
name: str,
model: Model[Iterable[Doc], Floats2d],
spans_key: str,
threshold: float,
max_length: Optional[int],
min_length: Optional[int],
scorer: Optional[Callable],
) -> SpanFinder:
return SpanFinder(
nlp,
model=model,
threshold=threshold,
name=name,
scorer=scorer,
max_length=max_length,
min_length=min_length,
spans_key=spans_key,
)
def make_ner(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
incorrect_spans_key: Optional[str],
scorer: Optional[Callable],
):
return EntityRecognizer(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
incorrect_spans_key=incorrect_spans_key,
scorer=scorer,
)
def make_beam_ner(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
incorrect_spans_key: Optional[str],
scorer: Optional[Callable],
):
return EntityRecognizer(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
incorrect_spans_key=incorrect_spans_key,
scorer=scorer,
)
def make_parser(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
learn_tokens: bool,
min_action_freq: int,
scorer: Optional[Callable],
):
return DependencyParser(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
scorer=scorer,
)
def make_beam_parser(
nlp: Language,
name: str,
model: Model,
moves: Optional[TransitionSystem],
update_with_oracle_cut_size: int,
learn_tokens: bool,
min_action_freq: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
scorer: Optional[Callable],
):
return DependencyParser(
nlp.vocab,
model,
name=name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
scorer=scorer,
)
def make_tagger(
nlp: Language,
name: str,
model: Model,
overwrite: bool,
scorer: Optional[Callable],
neg_prefix: str,
label_smoothing: float,
):
return Tagger(
nlp.vocab,
model,
name=name,
overwrite=overwrite,
scorer=scorer,
neg_prefix=neg_prefix,
label_smoothing=label_smoothing,
)
def make_nn_labeller(
nlp: Language,
name: str,
model: Model,
labels: Optional[dict],
target: str
):
return MultitaskObjective(nlp.vocab, model, name, target=target)
def make_morphologizer(
nlp: Language,
model: Model,
name: str,
overwrite: bool,
extend: bool,
label_smoothing: float,
scorer: Optional[Callable],
):
return Morphologizer(
nlp.vocab, model, name,
overwrite=overwrite,
extend=extend,
label_smoothing=label_smoothing,
scorer=scorer
)
def make_senter(
nlp: Language,
name: str,
model: Model,
overwrite: bool,
scorer: Optional[Callable]
):
return SentenceRecognizer(
nlp.vocab, model, name,
overwrite=overwrite,
scorer=scorer
)

View File

@ -101,7 +101,11 @@ def test_cat_readers(reader, additional_config):
nlp = load_model_from_config(config, auto_fill=True)
T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]]
print("T", T)
print("dot names", dot_names)
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
data = list(train_corpus(nlp))
print(len(data))
optimizer = T["optimizer"]
# simulate a training loop
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)