From 823e533dc182eadce00695da3f8b0798de7ed4f9 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Wed, 5 Aug 2020 19:47:54 +0200 Subject: [PATCH] Add config callbacks for modifying nlp object before and after init (#5866) * WIP: Concept for modifying nlp object before and after init * Make callbacks return nlp object Co-authored-by: Matthew Honnibal * Raise if callbacks don't return correct type * Rename, update types, add after_pipeline_creation Co-authored-by: Matthew Honnibal --- spacy/default_config.cfg | 3 ++ spacy/errors.py | 7 +++ spacy/language.py | 24 +++++++++- spacy/schemas.py | 3 ++ spacy/tests/test_language.py | 88 ++++++++++++++++++++++++++++++++++-- spacy/util.py | 2 + 6 files changed, 122 insertions(+), 5 deletions(-) diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index f35be605c..353924280 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -12,6 +12,9 @@ use_pytorch_for_gpu_memory = false lang = null pipeline = [] load_vocab_data = true +before_creation = null +after_creation = null +after_pipeline_creation = null [nlp.tokenizer] @tokenizers = "spacy.Tokenizer.v1" diff --git a/spacy/errors.py b/spacy/errors.py index 973843bb7..378641fec 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -482,6 +482,13 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E942 = ("Executing after_{name} callback failed. Expected the function to " + "return an initialized nlp object but got: {value}. Maybe " + "you forgot to return the modified object in your function?") + E943 = ("Executing before_creation callback failed. Expected the function to " + "return an uninitialized Language subclass but got: {value}. Maybe " + "you forgot to return the modified object in your function or " + "returned the initialized nlp object instead?") E944 = ("Can't copy pipeline component '{name}' from source model '{model}': " "not found in pipeline. Available components: {opts}") E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded " diff --git a/spacy/language.py b/spacy/language.py index 31bb744db..4b44b9820 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1464,11 +1464,27 @@ class Language: config["components"] = orig_pipeline create_tokenizer = resolved["nlp"]["tokenizer"] create_lemmatizer = resolved["nlp"]["lemmatizer"] - nlp = cls( + before_creation = resolved["nlp"]["before_creation"] + after_creation = resolved["nlp"]["after_creation"] + after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"] + lang_cls = cls + if before_creation is not None: + lang_cls = before_creation(cls) + if ( + not isinstance(lang_cls, type) + or not issubclass(lang_cls, cls) + or lang_cls is not cls + ): + raise ValueError(Errors.E943.format(value=type(lang_cls))) + nlp = lang_cls( vocab=vocab, create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, ) + if after_creation is not None: + nlp = after_creation(nlp) + if not isinstance(nlp, cls): + raise ValueError(Errors.E942.format(name="creation", value=type(nlp))) # Note that we don't load vectors here, instead they get loaded explicitly # inside stuff like the spacy train function. If we loaded them here, # then we would load them twice at runtime: once when we make from config, @@ -1509,6 +1525,12 @@ class Language: nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.config = filled if auto_fill else config nlp.resolved = resolved + if after_pipeline_creation is not None: + nlp = after_pipeline_creation(nlp) + if not isinstance(nlp, cls): + raise ValueError( + Errors.E942.format(name="pipeline_creation", value=type(nlp)) + ) return nlp def to_disk( diff --git a/spacy/schemas.py b/spacy/schemas.py index 745d46333..d599ccbb2 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -222,6 +222,9 @@ class ConfigSchemaNlp(BaseModel): tokenizer: Callable = Field(..., title="The tokenizer to use") lemmatizer: Callable = Field(..., title="The lemmatizer to use") load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data") + before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization") + after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed") + after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed") # fmt: on class Config: diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index a63a8e24c..6865cd1e5 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -3,10 +3,11 @@ import pytest from spacy.language import Language from spacy.tokens import Doc, Span from spacy.vocab import Vocab +from spacy.gold import Example from spacy.lang.en import English +from spacy.util import registry from .util import add_vecs_to_vocab, assert_docs_equal -from ..gold import Example @pytest.fixture @@ -153,6 +154,85 @@ def test_language_pipe_stream(nlp2, n_process, texts): assert_docs_equal(doc, expected_doc) -def test_language_from_config(): - English.from_config() - # TODO: add more tests +def test_language_from_config_before_after_init(): + name = "test_language_from_config_before_after_init" + ran_before = False + ran_after = False + ran_after_pipeline = False + + @registry.callbacks(f"{name}_before") + def make_before_creation(): + def before_creation(lang_cls): + nonlocal ran_before + ran_before = True + assert lang_cls is English + lang_cls.Defaults.foo = "bar" + return lang_cls + + return before_creation + + @registry.callbacks(f"{name}_after") + def make_after_creation(): + def after_creation(nlp): + nonlocal ran_after + ran_after = True + assert isinstance(nlp, English) + assert nlp.pipe_names == [] + assert nlp.Defaults.foo == "bar" + nlp.meta["foo"] = "bar" + return nlp + + return after_creation + + @registry.callbacks(f"{name}_after_pipeline") + def make_after_pipeline_creation(): + def after_pipeline_creation(nlp): + nonlocal ran_after_pipeline + ran_after_pipeline = True + assert isinstance(nlp, English) + assert nlp.pipe_names == ["sentencizer"] + assert nlp.Defaults.foo == "bar" + assert nlp.meta["foo"] == "bar" + nlp.meta["bar"] = "baz" + return nlp + + return after_pipeline_creation + + config = { + "nlp": { + "pipeline": ["sentencizer"], + "before_creation": {"@callbacks": f"{name}_before"}, + "after_creation": {"@callbacks": f"{name}_after"}, + "after_pipeline_creation": {"@callbacks": f"{name}_after_pipeline"}, + }, + "components": {"sentencizer": {"factory": "sentencizer"}}, + } + nlp = English.from_config(config) + assert all([ran_before, ran_after, ran_after_pipeline]) + assert nlp.Defaults.foo == "bar" + assert nlp.meta["foo"] == "bar" + assert nlp.meta["bar"] == "baz" + assert nlp.pipe_names == ["sentencizer"] + assert nlp("text") + + +def test_language_from_config_before_after_init_invalid(): + """Check that an error is raised if function doesn't return nlp.""" + name = "test_language_from_config_before_after_init_invalid" + registry.callbacks(f"{name}_before1", func=lambda: lambda nlp: None) + registry.callbacks(f"{name}_before2", func=lambda: lambda nlp: nlp()) + registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: None) + registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: English) + + for callback_name in [f"{name}_before1", f"{name}_before2"]: + config = {"nlp": {"before_creation": {"@callbacks": callback_name}}} + with pytest.raises(ValueError): + English.from_config(config) + for callback_name in [f"{name}_after1", f"{name}_after2"]: + config = {"nlp": {"after_creation": {"@callbacks": callback_name}}} + with pytest.raises(ValueError): + English.from_config(config) + for callback_name in [f"{name}_after1", f"{name}_after2"]: + config = {"nlp": {"after_pipeline_creation": {"@callbacks": callback_name}}} + with pytest.raises(ValueError): + English.from_config(config) diff --git a/spacy/util.py b/spacy/util.py index 96257896f..b5140d420 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -67,6 +67,8 @@ class registry(thinc.registry): lookups = catalogue.create("spacy", "lookups", entry_points=True) displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True) assets = catalogue.create("spacy", "assets", entry_points=True) + # Callback functions used to manipulate nlp object etc. + callbacks = catalogue.create("spacy", "callbacks") batchers = catalogue.create("spacy", "batchers", entry_points=True) readers = catalogue.create("spacy", "readers", entry_points=True) # These are factories registered via third-party packages and the