mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Add config callbacks for modifying nlp object before and after init (#5866)
* WIP: Concept for modifying nlp object before and after init * Make callbacks return nlp object Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> * Raise if callbacks don't return correct type * Rename, update types, add after_pipeline_creation Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
586d695775
commit
823e533dc1
|
@ -12,6 +12,9 @@ use_pytorch_for_gpu_memory = false
|
|||
lang = null
|
||||
pipeline = []
|
||||
load_vocab_data = true
|
||||
before_creation = null
|
||||
after_creation = null
|
||||
after_pipeline_creation = null
|
||||
|
||||
[nlp.tokenizer]
|
||||
@tokenizers = "spacy.Tokenizer.v1"
|
||||
|
|
|
@ -482,6 +482,13 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E942 = ("Executing after_{name} callback failed. Expected the function to "
|
||||
"return an initialized nlp object but got: {value}. Maybe "
|
||||
"you forgot to return the modified object in your function?")
|
||||
E943 = ("Executing before_creation callback failed. Expected the function to "
|
||||
"return an uninitialized Language subclass but got: {value}. Maybe "
|
||||
"you forgot to return the modified object in your function or "
|
||||
"returned the initialized nlp object instead?")
|
||||
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
|
||||
"not found in pipeline. Available components: {opts}")
|
||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||
|
|
|
@ -1464,11 +1464,27 @@ class Language:
|
|||
config["components"] = orig_pipeline
|
||||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
||||
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
||||
nlp = cls(
|
||||
before_creation = resolved["nlp"]["before_creation"]
|
||||
after_creation = resolved["nlp"]["after_creation"]
|
||||
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
|
||||
lang_cls = cls
|
||||
if before_creation is not None:
|
||||
lang_cls = before_creation(cls)
|
||||
if (
|
||||
not isinstance(lang_cls, type)
|
||||
or not issubclass(lang_cls, cls)
|
||||
or lang_cls is not cls
|
||||
):
|
||||
raise ValueError(Errors.E943.format(value=type(lang_cls)))
|
||||
nlp = lang_cls(
|
||||
vocab=vocab,
|
||||
create_tokenizer=create_tokenizer,
|
||||
create_lemmatizer=create_lemmatizer,
|
||||
)
|
||||
if after_creation is not None:
|
||||
nlp = after_creation(nlp)
|
||||
if not isinstance(nlp, cls):
|
||||
raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
|
||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||
# inside stuff like the spacy train function. If we loaded them here,
|
||||
# then we would load them twice at runtime: once when we make from config,
|
||||
|
@ -1509,6 +1525,12 @@ class Language:
|
|||
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
||||
nlp.config = filled if auto_fill else config
|
||||
nlp.resolved = resolved
|
||||
if after_pipeline_creation is not None:
|
||||
nlp = after_pipeline_creation(nlp)
|
||||
if not isinstance(nlp, cls):
|
||||
raise ValueError(
|
||||
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
||||
)
|
||||
return nlp
|
||||
|
||||
def to_disk(
|
||||
|
|
|
@ -222,6 +222,9 @@ class ConfigSchemaNlp(BaseModel):
|
|||
tokenizer: Callable = Field(..., title="The tokenizer to use")
|
||||
lemmatizer: Callable = Field(..., title="The lemmatizer to use")
|
||||
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
|
||||
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
|
||||
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
|
||||
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
|
|
@ -3,10 +3,11 @@ import pytest
|
|||
from spacy.language import Language
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.gold import Example
|
||||
from spacy.lang.en import English
|
||||
from spacy.util import registry
|
||||
|
||||
from .util import add_vecs_to_vocab, assert_docs_equal
|
||||
from ..gold import Example
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -153,6 +154,85 @@ def test_language_pipe_stream(nlp2, n_process, texts):
|
|||
assert_docs_equal(doc, expected_doc)
|
||||
|
||||
|
||||
def test_language_from_config():
|
||||
English.from_config()
|
||||
# TODO: add more tests
|
||||
def test_language_from_config_before_after_init():
|
||||
name = "test_language_from_config_before_after_init"
|
||||
ran_before = False
|
||||
ran_after = False
|
||||
ran_after_pipeline = False
|
||||
|
||||
@registry.callbacks(f"{name}_before")
|
||||
def make_before_creation():
|
||||
def before_creation(lang_cls):
|
||||
nonlocal ran_before
|
||||
ran_before = True
|
||||
assert lang_cls is English
|
||||
lang_cls.Defaults.foo = "bar"
|
||||
return lang_cls
|
||||
|
||||
return before_creation
|
||||
|
||||
@registry.callbacks(f"{name}_after")
|
||||
def make_after_creation():
|
||||
def after_creation(nlp):
|
||||
nonlocal ran_after
|
||||
ran_after = True
|
||||
assert isinstance(nlp, English)
|
||||
assert nlp.pipe_names == []
|
||||
assert nlp.Defaults.foo == "bar"
|
||||
nlp.meta["foo"] = "bar"
|
||||
return nlp
|
||||
|
||||
return after_creation
|
||||
|
||||
@registry.callbacks(f"{name}_after_pipeline")
|
||||
def make_after_pipeline_creation():
|
||||
def after_pipeline_creation(nlp):
|
||||
nonlocal ran_after_pipeline
|
||||
ran_after_pipeline = True
|
||||
assert isinstance(nlp, English)
|
||||
assert nlp.pipe_names == ["sentencizer"]
|
||||
assert nlp.Defaults.foo == "bar"
|
||||
assert nlp.meta["foo"] == "bar"
|
||||
nlp.meta["bar"] = "baz"
|
||||
return nlp
|
||||
|
||||
return after_pipeline_creation
|
||||
|
||||
config = {
|
||||
"nlp": {
|
||||
"pipeline": ["sentencizer"],
|
||||
"before_creation": {"@callbacks": f"{name}_before"},
|
||||
"after_creation": {"@callbacks": f"{name}_after"},
|
||||
"after_pipeline_creation": {"@callbacks": f"{name}_after_pipeline"},
|
||||
},
|
||||
"components": {"sentencizer": {"factory": "sentencizer"}},
|
||||
}
|
||||
nlp = English.from_config(config)
|
||||
assert all([ran_before, ran_after, ran_after_pipeline])
|
||||
assert nlp.Defaults.foo == "bar"
|
||||
assert nlp.meta["foo"] == "bar"
|
||||
assert nlp.meta["bar"] == "baz"
|
||||
assert nlp.pipe_names == ["sentencizer"]
|
||||
assert nlp("text")
|
||||
|
||||
|
||||
def test_language_from_config_before_after_init_invalid():
|
||||
"""Check that an error is raised if function doesn't return nlp."""
|
||||
name = "test_language_from_config_before_after_init_invalid"
|
||||
registry.callbacks(f"{name}_before1", func=lambda: lambda nlp: None)
|
||||
registry.callbacks(f"{name}_before2", func=lambda: lambda nlp: nlp())
|
||||
registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: None)
|
||||
registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: English)
|
||||
|
||||
for callback_name in [f"{name}_before1", f"{name}_before2"]:
|
||||
config = {"nlp": {"before_creation": {"@callbacks": callback_name}}}
|
||||
with pytest.raises(ValueError):
|
||||
English.from_config(config)
|
||||
for callback_name in [f"{name}_after1", f"{name}_after2"]:
|
||||
config = {"nlp": {"after_creation": {"@callbacks": callback_name}}}
|
||||
with pytest.raises(ValueError):
|
||||
English.from_config(config)
|
||||
for callback_name in [f"{name}_after1", f"{name}_after2"]:
|
||||
config = {"nlp": {"after_pipeline_creation": {"@callbacks": callback_name}}}
|
||||
with pytest.raises(ValueError):
|
||||
English.from_config(config)
|
||||
|
|
|
@ -67,6 +67,8 @@ class registry(thinc.registry):
|
|||
lookups = catalogue.create("spacy", "lookups", entry_points=True)
|
||||
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
|
||||
assets = catalogue.create("spacy", "assets", entry_points=True)
|
||||
# Callback functions used to manipulate nlp object etc.
|
||||
callbacks = catalogue.create("spacy", "callbacks")
|
||||
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
||||
readers = catalogue.create("spacy", "readers", entry_points=True)
|
||||
# These are factories registered via third-party packages and the
|
||||
|
|
Loading…
Reference in New Issue
Block a user