Add config callbacks for modifying nlp object before and after init (#5866)

* WIP: Concept for modifying nlp object before and after init

* Make callbacks return nlp object

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>

* Raise if callbacks don't return correct type

* Rename, update types, add after_pipeline_creation

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Ines Montani 2020-08-05 19:47:54 +02:00 committed by GitHub
parent 586d695775
commit 823e533dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 122 additions and 5 deletions

View File

@ -12,6 +12,9 @@ use_pytorch_for_gpu_memory = false
lang = null lang = null
pipeline = [] pipeline = []
load_vocab_data = true load_vocab_data = true
before_creation = null
after_creation = null
after_pipeline_creation = null
[nlp.tokenizer] [nlp.tokenizer]
@tokenizers = "spacy.Tokenizer.v1" @tokenizers = "spacy.Tokenizer.v1"

View File

@ -482,6 +482,13 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E942 = ("Executing after_{name} callback failed. Expected the function to "
"return an initialized nlp object but got: {value}. Maybe "
"you forgot to return the modified object in your function?")
E943 = ("Executing before_creation callback failed. Expected the function to "
"return an uninitialized Language subclass but got: {value}. Maybe "
"you forgot to return the modified object in your function or "
"returned the initialized nlp object instead?")
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': " E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
"not found in pipeline. Available components: {opts}") "not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded " E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "

View File

@ -1464,11 +1464,27 @@ class Language:
config["components"] = orig_pipeline config["components"] = orig_pipeline
create_tokenizer = resolved["nlp"]["tokenizer"] create_tokenizer = resolved["nlp"]["tokenizer"]
create_lemmatizer = resolved["nlp"]["lemmatizer"] create_lemmatizer = resolved["nlp"]["lemmatizer"]
nlp = cls( before_creation = resolved["nlp"]["before_creation"]
after_creation = resolved["nlp"]["after_creation"]
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
lang_cls = cls
if before_creation is not None:
lang_cls = before_creation(cls)
if (
not isinstance(lang_cls, type)
or not issubclass(lang_cls, cls)
or lang_cls is not cls
):
raise ValueError(Errors.E943.format(value=type(lang_cls)))
nlp = lang_cls(
vocab=vocab, vocab=vocab,
create_tokenizer=create_tokenizer, create_tokenizer=create_tokenizer,
create_lemmatizer=create_lemmatizer, create_lemmatizer=create_lemmatizer,
) )
if after_creation is not None:
nlp = after_creation(nlp)
if not isinstance(nlp, cls):
raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
# Note that we don't load vectors here, instead they get loaded explicitly # Note that we don't load vectors here, instead they get loaded explicitly
# inside stuff like the spacy train function. If we loaded them here, # inside stuff like the spacy train function. If we loaded them here,
# then we would load them twice at runtime: once when we make from config, # then we would load them twice at runtime: once when we make from config,
@ -1509,6 +1525,12 @@ class Language:
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved nlp.resolved = resolved
if after_pipeline_creation is not None:
nlp = after_pipeline_creation(nlp)
if not isinstance(nlp, cls):
raise ValueError(
Errors.E942.format(name="pipeline_creation", value=type(nlp))
)
return nlp return nlp
def to_disk( def to_disk(

View File

@ -222,6 +222,9 @@ class ConfigSchemaNlp(BaseModel):
tokenizer: Callable = Field(..., title="The tokenizer to use") tokenizer: Callable = Field(..., title="The tokenizer to use")
lemmatizer: Callable = Field(..., title="The lemmatizer to use") lemmatizer: Callable = Field(..., title="The lemmatizer to use")
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data") load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
# fmt: on # fmt: on
class Config: class Config:

View File

@ -3,10 +3,11 @@ import pytest
from spacy.language import Language from spacy.language import Language
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.gold import Example
from spacy.lang.en import English from spacy.lang.en import English
from spacy.util import registry
from .util import add_vecs_to_vocab, assert_docs_equal from .util import add_vecs_to_vocab, assert_docs_equal
from ..gold import Example
@pytest.fixture @pytest.fixture
@ -153,6 +154,85 @@ def test_language_pipe_stream(nlp2, n_process, texts):
assert_docs_equal(doc, expected_doc) assert_docs_equal(doc, expected_doc)
def test_language_from_config(): def test_language_from_config_before_after_init():
English.from_config() name = "test_language_from_config_before_after_init"
# TODO: add more tests ran_before = False
ran_after = False
ran_after_pipeline = False
@registry.callbacks(f"{name}_before")
def make_before_creation():
def before_creation(lang_cls):
nonlocal ran_before
ran_before = True
assert lang_cls is English
lang_cls.Defaults.foo = "bar"
return lang_cls
return before_creation
@registry.callbacks(f"{name}_after")
def make_after_creation():
def after_creation(nlp):
nonlocal ran_after
ran_after = True
assert isinstance(nlp, English)
assert nlp.pipe_names == []
assert nlp.Defaults.foo == "bar"
nlp.meta["foo"] = "bar"
return nlp
return after_creation
@registry.callbacks(f"{name}_after_pipeline")
def make_after_pipeline_creation():
def after_pipeline_creation(nlp):
nonlocal ran_after_pipeline
ran_after_pipeline = True
assert isinstance(nlp, English)
assert nlp.pipe_names == ["sentencizer"]
assert nlp.Defaults.foo == "bar"
assert nlp.meta["foo"] == "bar"
nlp.meta["bar"] = "baz"
return nlp
return after_pipeline_creation
config = {
"nlp": {
"pipeline": ["sentencizer"],
"before_creation": {"@callbacks": f"{name}_before"},
"after_creation": {"@callbacks": f"{name}_after"},
"after_pipeline_creation": {"@callbacks": f"{name}_after_pipeline"},
},
"components": {"sentencizer": {"factory": "sentencizer"}},
}
nlp = English.from_config(config)
assert all([ran_before, ran_after, ran_after_pipeline])
assert nlp.Defaults.foo == "bar"
assert nlp.meta["foo"] == "bar"
assert nlp.meta["bar"] == "baz"
assert nlp.pipe_names == ["sentencizer"]
assert nlp("text")
def test_language_from_config_before_after_init_invalid():
    """Check that an error is raised if a callback returns the wrong object.

    Covers callbacks that return nothing, a before_creation callback that
    returns an instance instead of a class, and after_* callbacks that return
    a class instead of an nlp object.
    """
    name = "test_language_from_config_before_after_init_invalid"
    # before_creation receives the Language class: returning None or an
    # initialized instance are both invalid.
    registry.callbacks(f"{name}_before1", func=lambda: lambda lang_cls: None)
    registry.callbacks(f"{name}_before2", func=lambda: lambda lang_cls: lang_cls())
    # after_* callbacks receive the nlp object: returning None or the class
    # itself are both invalid.
    registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: None)
    # This must be registered as "_after2". Registering it under "_after1"
    # again would replace the None-returning callback above and leave the
    # "_after2" name used in the loops below without a registered function.
    registry.callbacks(f"{name}_after2", func=lambda: lambda nlp: English)
    for callback_name in [f"{name}_before1", f"{name}_before2"]:
        config = {"nlp": {"before_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError):
            English.from_config(config)
    for callback_name in [f"{name}_after1", f"{name}_after2"]:
        config = {"nlp": {"after_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError):
            English.from_config(config)
    for callback_name in [f"{name}_after1", f"{name}_after2"]:
        config = {"nlp": {"after_pipeline_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError):
            English.from_config(config)

View File

@ -67,6 +67,8 @@ class registry(thinc.registry):
lookups = catalogue.create("spacy", "lookups", entry_points=True) lookups = catalogue.create("spacy", "lookups", entry_points=True)
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True) displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
assets = catalogue.create("spacy", "assets", entry_points=True) assets = catalogue.create("spacy", "assets", entry_points=True)
# Callback functions used to manipulate nlp object etc.
callbacks = catalogue.create("spacy", "callbacks")
batchers = catalogue.create("spacy", "batchers", entry_points=True) batchers = catalogue.create("spacy", "batchers", entry_points=True)
readers = catalogue.create("spacy", "readers", entry_points=True) readers = catalogue.create("spacy", "readers", entry_points=True)
# These are factories registered via third-party packages and the # These are factories registered via third-party packages and the