# Source: spaCy/spacy/tests/pipeline/test_initialize.py
# (70 lines, 2.3 KiB, Python; captured from a repository blame/history view,
# first blame timestamp: 2020-09-29 13:54:52 +03:00)

import pytest
from spacy.language import Language
from spacy.lang.en import English
from spacy.training import Example
from thinc.api import ConfigValidationError
from pydantic import StrictBool
def test_initialize_arguments():
    """Exercise custom ``initialize`` arguments for a tokenizer and a component.

    The test installs a tokenizer wrapper and a pipeline component whose
    ``initialize`` methods take extra, type-annotated arguments, then puts
    settings for them under ``nlp.config["initialize"]`` and calls
    ``nlp.initialize`` three times:

    1. with no ``custom1`` for the component, expecting a
       ``ConfigValidationError`` of type ``value_error.missing``;
    2. with ``custom2=1`` (an int where ``StrictBool`` is annotated),
       expecting a ``ConfigValidationError`` of type
       ``value_error.strictbool``;
    3. with a valid config, expecting both objects to have recorded the
       values they were given (``custom2`` falling back to its default).
    """
    name = "test_initialize_arguments"

    class CustomTokenizer:
        """Wrap a tokenizer and remember the ``custom`` initialize argument."""

        def __init__(self, tokenizer):
            self.tokenizer = tokenizer
            # Stays None until ``initialize`` has been called.
            self.from_initialize = None

        def __call__(self, text):
            # Delegate the actual tokenization to the wrapped tokenizer.
            return self.tokenizer(text)

        def initialize(self, get_examples, nlp, custom: int):
            self.from_initialize = custom

    class Component:
        """Minimal component that remembers its custom initialize arguments."""

        def __init__(self):
            # Stays None until ``initialize`` has been called.
            self.from_initialize = None

        def initialize(
            self, get_examples, nlp, custom1: str, custom2: StrictBool = False
        ):
            self.from_initialize = (custom1, custom2)

    Language.factory(name, func=lambda nlp, name: Component())

    nlp = English()
    nlp.tokenizer = CustomTokenizer(nlp.tokenizer)
    example = Example.from_dict(nlp("x"), {})

    def get_examples():
        return [example]

    nlp.add_pipe(name)

    # The settings here will typically come from the [initialize] block
    init_cfg = {"tokenizer": {"custom": 1}, "components": {name: {}}}
    nlp.config["initialize"].update(init_cfg)
    with pytest.raises(ConfigValidationError) as e:
        # Empty config for component, no required custom1 argument
        nlp.initialize(get_examples)
    errors = e.value.errors
    assert len(errors) == 1
    assert errors[0]["loc"] == ("custom1",)
    assert errors[0]["type"] == "value_error.missing"

    init_cfg = {
        "tokenizer": {"custom": 1},
        "components": {name: {"custom1": "x", "custom2": 1}},
    }
    nlp.config["initialize"].update(init_cfg)
    with pytest.raises(ConfigValidationError) as e:
        # Wrong type of custom 2
        nlp.initialize(get_examples)
    errors = e.value.errors
    assert len(errors) == 1
    assert errors[0]["loc"] == ("custom2",)
    assert errors[0]["type"] == "value_error.strictbool"

    # Valid settings: custom2 is omitted so its default (False) should be used.
    init_cfg = {
        "tokenizer": {"custom": 1},
        "components": {name: {"custom1": "x"}},
    }
    nlp.config["initialize"].update(init_cfg)
    nlp.initialize(get_examples)
    assert nlp.tokenizer.from_initialize == 1
    pipe = nlp.get_pipe(name)
    assert pipe.from_initialize == ("x", False)