mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Add config callbacks for modifying nlp object before and after init (#5866)
* WIP: Concept for modifying nlp object before and after init
* Make callbacks return nlp object
  Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
* Raise if callbacks don't return correct type
* Rename, update types, add after_pipeline_creation
  Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
586d695775
commit
823e533dc1
|
@ -12,6 +12,9 @@ use_pytorch_for_gpu_memory = false
|
||||||
lang = null
|
lang = null
|
||||||
pipeline = []
|
pipeline = []
|
||||||
load_vocab_data = true
|
load_vocab_data = true
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
|
||||||
[nlp.tokenizer]
|
[nlp.tokenizer]
|
||||||
@tokenizers = "spacy.Tokenizer.v1"
|
@tokenizers = "spacy.Tokenizer.v1"
|
||||||
|
|
|
@ -482,6 +482,13 @@ class Errors:
|
||||||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E942 = ("Executing after_{name} callback failed. Expected the function to "
|
||||||
|
"return an initialized nlp object but got: {value}. Maybe "
|
||||||
|
"you forgot to return the modified object in your function?")
|
||||||
|
E943 = ("Executing before_creation callback failed. Expected the function to "
|
||||||
|
"return an uninitialized Language subclass but got: {value}. Maybe "
|
||||||
|
"you forgot to return the modified object in your function or "
|
||||||
|
"returned the initialized nlp object instead?")
|
||||||
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
|
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
|
||||||
"not found in pipeline. Available components: {opts}")
|
"not found in pipeline. Available components: {opts}")
|
||||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||||
|
|
|
@ -1464,11 +1464,27 @@ class Language:
|
||||||
config["components"] = orig_pipeline
|
config["components"] = orig_pipeline
|
||||||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
create_tokenizer = resolved["nlp"]["tokenizer"]
|
||||||
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
||||||
nlp = cls(
|
before_creation = resolved["nlp"]["before_creation"]
|
||||||
|
after_creation = resolved["nlp"]["after_creation"]
|
||||||
|
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
|
||||||
|
lang_cls = cls
|
||||||
|
if before_creation is not None:
|
||||||
|
lang_cls = before_creation(cls)
|
||||||
|
if (
|
||||||
|
not isinstance(lang_cls, type)
|
||||||
|
or not issubclass(lang_cls, cls)
|
||||||
|
or lang_cls is not cls
|
||||||
|
):
|
||||||
|
raise ValueError(Errors.E943.format(value=type(lang_cls)))
|
||||||
|
nlp = lang_cls(
|
||||||
vocab=vocab,
|
vocab=vocab,
|
||||||
create_tokenizer=create_tokenizer,
|
create_tokenizer=create_tokenizer,
|
||||||
create_lemmatizer=create_lemmatizer,
|
create_lemmatizer=create_lemmatizer,
|
||||||
)
|
)
|
||||||
|
if after_creation is not None:
|
||||||
|
nlp = after_creation(nlp)
|
||||||
|
if not isinstance(nlp, cls):
|
||||||
|
raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
|
||||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||||
# inside stuff like the spacy train function. If we loaded them here,
|
# inside stuff like the spacy train function. If we loaded them here,
|
||||||
# then we would load them twice at runtime: once when we make from config,
|
# then we would load them twice at runtime: once when we make from config,
|
||||||
|
@ -1509,6 +1525,12 @@ class Language:
|
||||||
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
||||||
nlp.config = filled if auto_fill else config
|
nlp.config = filled if auto_fill else config
|
||||||
nlp.resolved = resolved
|
nlp.resolved = resolved
|
||||||
|
if after_pipeline_creation is not None:
|
||||||
|
nlp = after_pipeline_creation(nlp)
|
||||||
|
if not isinstance(nlp, cls):
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
||||||
|
)
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
def to_disk(
|
def to_disk(
|
||||||
|
|
|
@ -222,6 +222,9 @@ class ConfigSchemaNlp(BaseModel):
|
||||||
tokenizer: Callable = Field(..., title="The tokenizer to use")
|
tokenizer: Callable = Field(..., title="The tokenizer to use")
|
||||||
lemmatizer: Callable = Field(..., title="The lemmatizer to use")
|
lemmatizer: Callable = Field(..., title="The lemmatizer to use")
|
||||||
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
|
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
|
||||||
|
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
|
||||||
|
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
|
||||||
|
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -3,10 +3,11 @@ import pytest
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
from spacy.gold import Example
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
from spacy.util import registry
|
||||||
|
|
||||||
from .util import add_vecs_to_vocab, assert_docs_equal
|
from .util import add_vecs_to_vocab, assert_docs_equal
|
||||||
from ..gold import Example
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -153,6 +154,85 @@ def test_language_pipe_stream(nlp2, n_process, texts):
|
||||||
assert_docs_equal(doc, expected_doc)
|
assert_docs_equal(doc, expected_doc)
|
||||||
|
|
||||||
|
|
||||||
def test_language_from_config():
|
def test_language_from_config_before_after_init():
    """Check that the three nlp callbacks run and that their changes persist.

    Registers a before_creation, an after_creation and an
    after_pipeline_creation callback, builds an English pipeline from a
    config that references them, and verifies the modifications each
    callback made to the class or the nlp object.
    """
    name = "test_language_from_config_before_after_init"
    # Records which of the callbacks were actually invoked.
    called = {"before": False, "after": False, "after_pipeline": False}

    def before_creation(lang_cls):
        # Receives the Language class itself, before it is instantiated.
        called["before"] = True
        assert lang_cls is English
        lang_cls.Defaults.foo = "bar"
        return lang_cls

    def after_creation(nlp):
        # Receives the nlp object while the pipeline is still empty; the
        # class-level change from before_creation must already be visible.
        called["after"] = True
        assert isinstance(nlp, English)
        assert nlp.pipe_names == []
        assert nlp.Defaults.foo == "bar"
        nlp.meta["foo"] = "bar"
        return nlp

    def after_pipeline_creation(nlp):
        # Receives the nlp object once the components have been added; the
        # meta change from after_creation must already be visible.
        called["after_pipeline"] = True
        assert isinstance(nlp, English)
        assert nlp.pipe_names == ["sentencizer"]
        assert nlp.Defaults.foo == "bar"
        assert nlp.meta["foo"] == "bar"
        nlp.meta["bar"] = "baz"
        return nlp

    # Each registry entry is a zero-argument factory returning the callback.
    registry.callbacks(f"{name}_before", func=lambda: before_creation)
    registry.callbacks(f"{name}_after", func=lambda: after_creation)
    registry.callbacks(f"{name}_after_pipeline", func=lambda: after_pipeline_creation)

    nlp_section = {
        "pipeline": ["sentencizer"],
        "before_creation": {"@callbacks": f"{name}_before"},
        "after_creation": {"@callbacks": f"{name}_after"},
        "after_pipeline_creation": {"@callbacks": f"{name}_after_pipeline"},
    }
    config = {
        "nlp": nlp_section,
        "components": {"sentencizer": {"factory": "sentencizer"}},
    }
    nlp = English.from_config(config)
    assert called["before"] and called["after"] and called["after_pipeline"]
    assert nlp.Defaults.foo == "bar"
    assert nlp.meta["foo"] == "bar"
    assert nlp.meta["bar"] == "baz"
    assert nlp.pipe_names == ["sentencizer"]
    assert nlp("text")
|
||||||
|
|
||||||
|
|
||||||
|
def test_language_from_config_before_after_init_invalid():
    """Check that an error is raised if a callback returns the wrong object.

    before_creation has to return the uninitialized Language class, while
    after_creation and after_pipeline_creation have to return an initialized
    nlp object. Each invalid return value must make from_config raise a
    ValueError that names the failing callback.
    """
    name = "test_language_from_config_before_after_init_invalid"
    # Invalid before_creation results: nothing at all, or an nlp instance
    # instead of the class.
    registry.callbacks(f"{name}_before1", func=lambda: lambda lang_cls: None)
    registry.callbacks(f"{name}_before2", func=lambda: lambda lang_cls: lang_cls())
    # Invalid after_* results: nothing at all, or the class instead of an
    # nlp instance. The second entry needs its own "_after2" name: reusing
    # "_after1" would overwrite the first case and leave "_after2" (used in
    # the loops below) unregistered.
    registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: None)
    registry.callbacks(f"{name}_after2", func=lambda: lambda nlp: English)

    # `match` pins the error to the callback check itself, so an unrelated
    # ValueError raised elsewhere in from_config cannot satisfy the test.
    for callback_name in [f"{name}_before1", f"{name}_before2"]:
        config = {"nlp": {"before_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError, match="before_creation"):
            English.from_config(config)
    for callback_name in [f"{name}_after1", f"{name}_after2"]:
        config = {"nlp": {"after_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError, match="after_creation"):
            English.from_config(config)
    for callback_name in [f"{name}_after1", f"{name}_after2"]:
        config = {"nlp": {"after_pipeline_creation": {"@callbacks": callback_name}}}
        with pytest.raises(ValueError, match="after_pipeline_creation"):
            English.from_config(config)
|
||||||
|
|
|
@ -67,6 +67,8 @@ class registry(thinc.registry):
|
||||||
lookups = catalogue.create("spacy", "lookups", entry_points=True)
|
lookups = catalogue.create("spacy", "lookups", entry_points=True)
|
||||||
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
|
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
|
||||||
assets = catalogue.create("spacy", "assets", entry_points=True)
|
assets = catalogue.create("spacy", "assets", entry_points=True)
|
||||||
|
# Callback functions used to manipulate nlp object etc.
|
||||||
|
callbacks = catalogue.create("spacy", "callbacks")
|
||||||
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
||||||
readers = catalogue.create("spacy", "readers", entry_points=True)
|
readers = catalogue.create("spacy", "readers", entry_points=True)
|
||||||
# These are factories registered via third-party packages and the
|
# These are factories registered via third-party packages and the
|
||||||
|
|
Loading…
Reference in New Issue
Block a user