mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add config callbacks for modifying nlp object before and after init (#5866)
* WIP: Concept for modifying nlp object before and after init * Make callbacks return nlp object Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> * Raise if callbacks don't return correct type * Rename, update types, add after_pipeline_creation Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
		
							parent
							
								
									586d695775
								
							
						
					
					
						commit
						823e533dc1
					
				| 
						 | 
					@ -12,6 +12,9 @@ use_pytorch_for_gpu_memory = false
 | 
				
			||||||
lang = null
 | 
					lang = null
 | 
				
			||||||
pipeline = []
 | 
					pipeline = []
 | 
				
			||||||
load_vocab_data = true
 | 
					load_vocab_data = true
 | 
				
			||||||
 | 
					before_creation = null
 | 
				
			||||||
 | 
					after_creation = null
 | 
				
			||||||
 | 
					after_pipeline_creation = null
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[nlp.tokenizer]
 | 
					[nlp.tokenizer]
 | 
				
			||||||
@tokenizers = "spacy.Tokenizer.v1"
 | 
					@tokenizers = "spacy.Tokenizer.v1"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -482,6 +482,13 @@ class Errors:
 | 
				
			||||||
    E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
 | 
					    E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO: fix numbering after merging develop into master
 | 
					    # TODO: fix numbering after merging develop into master
 | 
				
			||||||
 | 
					    E942 = ("Executing after_{name} callback failed. Expected the function to "
 | 
				
			||||||
 | 
					            "return an initialized nlp object but got: {value}. Maybe "
 | 
				
			||||||
 | 
					            "you forgot to return the modified object in your function?")
 | 
				
			||||||
 | 
					    E943 = ("Executing before_creation callback failed. Expected the function to "
 | 
				
			||||||
 | 
					            "return an uninitialized Language subclass but got: {value}. Maybe "
 | 
				
			||||||
 | 
					            "you forgot to return the modified object in your function or "
 | 
				
			||||||
 | 
					            "returned the initialized nlp object instead?")
 | 
				
			||||||
    E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
 | 
					    E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
 | 
				
			||||||
            "not found in pipeline. Available components: {opts}")
 | 
					            "not found in pipeline. Available components: {opts}")
 | 
				
			||||||
    E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
 | 
					    E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1464,11 +1464,27 @@ class Language:
 | 
				
			||||||
        config["components"] = orig_pipeline
 | 
					        config["components"] = orig_pipeline
 | 
				
			||||||
        create_tokenizer = resolved["nlp"]["tokenizer"]
 | 
					        create_tokenizer = resolved["nlp"]["tokenizer"]
 | 
				
			||||||
        create_lemmatizer = resolved["nlp"]["lemmatizer"]
 | 
					        create_lemmatizer = resolved["nlp"]["lemmatizer"]
 | 
				
			||||||
        nlp = cls(
 | 
					        before_creation = resolved["nlp"]["before_creation"]
 | 
				
			||||||
 | 
					        after_creation = resolved["nlp"]["after_creation"]
 | 
				
			||||||
 | 
					        after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
 | 
				
			||||||
 | 
					        lang_cls = cls
 | 
				
			||||||
 | 
					        if before_creation is not None:
 | 
				
			||||||
 | 
					            lang_cls = before_creation(cls)
 | 
				
			||||||
 | 
					            if (
 | 
				
			||||||
 | 
					                not isinstance(lang_cls, type)
 | 
				
			||||||
 | 
					                or not issubclass(lang_cls, cls)
 | 
				
			||||||
 | 
					                or lang_cls is not cls
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
 | 
					                raise ValueError(Errors.E943.format(value=type(lang_cls)))
 | 
				
			||||||
 | 
					        nlp = lang_cls(
 | 
				
			||||||
            vocab=vocab,
 | 
					            vocab=vocab,
 | 
				
			||||||
            create_tokenizer=create_tokenizer,
 | 
					            create_tokenizer=create_tokenizer,
 | 
				
			||||||
            create_lemmatizer=create_lemmatizer,
 | 
					            create_lemmatizer=create_lemmatizer,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        if after_creation is not None:
 | 
				
			||||||
 | 
					            nlp = after_creation(nlp)
 | 
				
			||||||
 | 
					            if not isinstance(nlp, cls):
 | 
				
			||||||
 | 
					                raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
 | 
				
			||||||
        # Note that we don't load vectors here, instead they get loaded explicitly
 | 
					        # Note that we don't load vectors here, instead they get loaded explicitly
 | 
				
			||||||
        # inside stuff like the spacy train function. If we loaded them here,
 | 
					        # inside stuff like the spacy train function. If we loaded them here,
 | 
				
			||||||
        # then we would load them twice at runtime: once when we make from config,
 | 
					        # then we would load them twice at runtime: once when we make from config,
 | 
				
			||||||
| 
						 | 
					@ -1509,6 +1525,12 @@ class Language:
 | 
				
			||||||
                    nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
 | 
					                    nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
 | 
				
			||||||
        nlp.config = filled if auto_fill else config
 | 
					        nlp.config = filled if auto_fill else config
 | 
				
			||||||
        nlp.resolved = resolved
 | 
					        nlp.resolved = resolved
 | 
				
			||||||
 | 
					        if after_pipeline_creation is not None:
 | 
				
			||||||
 | 
					            nlp = after_pipeline_creation(nlp)
 | 
				
			||||||
 | 
					            if not isinstance(nlp, cls):
 | 
				
			||||||
 | 
					                raise ValueError(
 | 
				
			||||||
 | 
					                    Errors.E942.format(name="pipeline_creation", value=type(nlp))
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
        return nlp
 | 
					        return nlp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_disk(
 | 
					    def to_disk(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -222,6 +222,9 @@ class ConfigSchemaNlp(BaseModel):
 | 
				
			||||||
    tokenizer: Callable = Field(..., title="The tokenizer to use")
 | 
					    tokenizer: Callable = Field(..., title="The tokenizer to use")
 | 
				
			||||||
    lemmatizer: Callable = Field(..., title="The lemmatizer to use")
 | 
					    lemmatizer: Callable = Field(..., title="The lemmatizer to use")
 | 
				
			||||||
    load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
 | 
					    load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
 | 
				
			||||||
 | 
					    before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
 | 
				
			||||||
 | 
					    after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
 | 
				
			||||||
 | 
					    after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
 | 
				
			||||||
    # fmt: on
 | 
					    # fmt: on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Config:
 | 
					    class Config:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,10 +3,11 @@ import pytest
 | 
				
			||||||
from spacy.language import Language
 | 
					from spacy.language import Language
 | 
				
			||||||
from spacy.tokens import Doc, Span
 | 
					from spacy.tokens import Doc, Span
 | 
				
			||||||
from spacy.vocab import Vocab
 | 
					from spacy.vocab import Vocab
 | 
				
			||||||
 | 
					from spacy.gold import Example
 | 
				
			||||||
from spacy.lang.en import English
 | 
					from spacy.lang.en import English
 | 
				
			||||||
 | 
					from spacy.util import registry
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .util import add_vecs_to_vocab, assert_docs_equal
 | 
					from .util import add_vecs_to_vocab, assert_docs_equal
 | 
				
			||||||
from ..gold import Example
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
| 
						 | 
					@ -153,6 +154,85 @@ def test_language_pipe_stream(nlp2, n_process, texts):
 | 
				
			||||||
        assert_docs_equal(doc, expected_doc)
 | 
					        assert_docs_equal(doc, expected_doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_language_from_config():
 | 
					def test_language_from_config_before_after_init():
 | 
				
			||||||
    English.from_config()
 | 
					    name = "test_language_from_config_before_after_init"
 | 
				
			||||||
    # TODO: add more tests
 | 
					    ran_before = False
 | 
				
			||||||
 | 
					    ran_after = False
 | 
				
			||||||
 | 
					    ran_after_pipeline = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @registry.callbacks(f"{name}_before")
 | 
				
			||||||
 | 
					    def make_before_creation():
 | 
				
			||||||
 | 
					        def before_creation(lang_cls):
 | 
				
			||||||
 | 
					            nonlocal ran_before
 | 
				
			||||||
 | 
					            ran_before = True
 | 
				
			||||||
 | 
					            assert lang_cls is English
 | 
				
			||||||
 | 
					            lang_cls.Defaults.foo = "bar"
 | 
				
			||||||
 | 
					            return lang_cls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return before_creation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @registry.callbacks(f"{name}_after")
 | 
				
			||||||
 | 
					    def make_after_creation():
 | 
				
			||||||
 | 
					        def after_creation(nlp):
 | 
				
			||||||
 | 
					            nonlocal ran_after
 | 
				
			||||||
 | 
					            ran_after = True
 | 
				
			||||||
 | 
					            assert isinstance(nlp, English)
 | 
				
			||||||
 | 
					            assert nlp.pipe_names == []
 | 
				
			||||||
 | 
					            assert nlp.Defaults.foo == "bar"
 | 
				
			||||||
 | 
					            nlp.meta["foo"] = "bar"
 | 
				
			||||||
 | 
					            return nlp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return after_creation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @registry.callbacks(f"{name}_after_pipeline")
 | 
				
			||||||
 | 
					    def make_after_pipeline_creation():
 | 
				
			||||||
 | 
					        def after_pipeline_creation(nlp):
 | 
				
			||||||
 | 
					            nonlocal ran_after_pipeline
 | 
				
			||||||
 | 
					            ran_after_pipeline = True
 | 
				
			||||||
 | 
					            assert isinstance(nlp, English)
 | 
				
			||||||
 | 
					            assert nlp.pipe_names == ["sentencizer"]
 | 
				
			||||||
 | 
					            assert nlp.Defaults.foo == "bar"
 | 
				
			||||||
 | 
					            assert nlp.meta["foo"] == "bar"
 | 
				
			||||||
 | 
					            nlp.meta["bar"] = "baz"
 | 
				
			||||||
 | 
					            return nlp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return after_pipeline_creation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    config = {
 | 
				
			||||||
 | 
					        "nlp": {
 | 
				
			||||||
 | 
					            "pipeline": ["sentencizer"],
 | 
				
			||||||
 | 
					            "before_creation": {"@callbacks": f"{name}_before"},
 | 
				
			||||||
 | 
					            "after_creation": {"@callbacks": f"{name}_after"},
 | 
				
			||||||
 | 
					            "after_pipeline_creation": {"@callbacks": f"{name}_after_pipeline"},
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        "components": {"sentencizer": {"factory": "sentencizer"}},
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    nlp = English.from_config(config)
 | 
				
			||||||
 | 
					    assert all([ran_before, ran_after, ran_after_pipeline])
 | 
				
			||||||
 | 
					    assert nlp.Defaults.foo == "bar"
 | 
				
			||||||
 | 
					    assert nlp.meta["foo"] == "bar"
 | 
				
			||||||
 | 
					    assert nlp.meta["bar"] == "baz"
 | 
				
			||||||
 | 
					    assert nlp.pipe_names == ["sentencizer"]
 | 
				
			||||||
 | 
					    assert nlp("text")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_language_from_config_before_after_init_invalid():
 | 
				
			||||||
 | 
					    """Check that an error is raised if function doesn't return nlp."""
 | 
				
			||||||
 | 
					    name = "test_language_from_config_before_after_init_invalid"
 | 
				
			||||||
 | 
					    registry.callbacks(f"{name}_before1", func=lambda: lambda nlp: None)
 | 
				
			||||||
 | 
					    registry.callbacks(f"{name}_before2", func=lambda: lambda nlp: nlp())
 | 
				
			||||||
 | 
					    registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: None)
 | 
				
			||||||
 | 
					    registry.callbacks(f"{name}_after1", func=lambda: lambda nlp: English)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for callback_name in [f"{name}_before1", f"{name}_before2"]:
 | 
				
			||||||
 | 
					        config = {"nlp": {"before_creation": {"@callbacks": callback_name}}}
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            English.from_config(config)
 | 
				
			||||||
 | 
					    for callback_name in [f"{name}_after1", f"{name}_after2"]:
 | 
				
			||||||
 | 
					        config = {"nlp": {"after_creation": {"@callbacks": callback_name}}}
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            English.from_config(config)
 | 
				
			||||||
 | 
					    for callback_name in [f"{name}_after1", f"{name}_after2"]:
 | 
				
			||||||
 | 
					        config = {"nlp": {"after_pipeline_creation": {"@callbacks": callback_name}}}
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            English.from_config(config)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -67,6 +67,8 @@ class registry(thinc.registry):
 | 
				
			||||||
    lookups = catalogue.create("spacy", "lookups", entry_points=True)
 | 
					    lookups = catalogue.create("spacy", "lookups", entry_points=True)
 | 
				
			||||||
    displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
 | 
					    displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
 | 
				
			||||||
    assets = catalogue.create("spacy", "assets", entry_points=True)
 | 
					    assets = catalogue.create("spacy", "assets", entry_points=True)
 | 
				
			||||||
 | 
					    # Callback functions used to manipulate nlp object etc.
 | 
				
			||||||
 | 
					    callbacks = catalogue.create("spacy", "callbacks")
 | 
				
			||||||
    batchers = catalogue.create("spacy", "batchers", entry_points=True)
 | 
					    batchers = catalogue.create("spacy", "batchers", entry_points=True)
 | 
				
			||||||
    readers = catalogue.create("spacy", "readers", entry_points=True)
 | 
					    readers = catalogue.create("spacy", "readers", entry_points=True)
 | 
				
			||||||
    # These are factories registered via third-party packages and the
 | 
					    # These are factories registered via third-party packages and the
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user