Fix is_cython_func for additional imported code

* Fix `is_cython_func` for imported code loaded under `python_code`
module name
* Add `make_named_tempfile` context manager to test utils to test
loading of imported code
* Add test for validation of `initialize` params in custom module
This commit is contained in:
Adriane Boyd 2021-03-01 16:32:31 +01:00
parent 10c930cc96
commit e9f7f9a4bc
3 changed files with 44 additions and 5 deletions

View File

@ -7,7 +7,7 @@ from spacy import util
from spacy import prefer_gpu, require_gpu, require_cpu from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList from spacy.util import dot_to_object, SimpleFrozenList, import_file
from thinc.api import Config, Optimizer, ConfigValidationError from thinc.api import Config, Optimizer, ConfigValidationError
from spacy.training.batchers import minibatch_by_words from spacy.training.batchers import minibatch_by_words
from spacy.lang.en import English from spacy.lang.en import English
@ -17,7 +17,7 @@ from spacy.schemas import ConfigSchemaTraining
from thinc.api import get_current_ops, NumpyOps, CupyOps from thinc.api import get_current_ops, NumpyOps, CupyOps
from .util import get_random_doc from .util import get_random_doc, make_named_tempfile
@pytest.fixture @pytest.fixture
@ -347,3 +347,34 @@ def test_resolve_dot_names():
errors = e.value.errors errors = e.value.errors
assert len(errors) == 1 assert len(errors) == 1
assert errors[0]["loc"] == ["training", "xyz"] assert errors[0]["loc"] == ["training", "xyz"]
def test_import_code():
code_str = """
from spacy import Language
class DummyComponent:
def __init__(self, vocab, name):
pass
def initialize(self, get_examples, *, nlp, dummy_param: int):
pass
@Language.factory(
"dummy_component",
)
def make_dummy_component(
nlp: Language, name: str
):
return DummyComponent(nlp.vocab, name)
"""
with make_named_tempfile(mode="w", suffix=".py") as fileh:
fileh.write(code_str)
fileh.flush()
import_file("python_code", fileh.name)
config = {"initialize": {"components": {"dummy_component": {"dummy_param": 1}}}}
nlp = English.from_config(config)
nlp.add_pipe("dummy_component")
nlp.initialize()

View File

@ -14,6 +14,13 @@ def make_tempfile(mode="r"):
f.close() f.close()
@contextlib.contextmanager
def make_named_tempfile(mode="r", suffix=None):
f = tempfile.NamedTemporaryFile(mode=mode, suffix=suffix)
yield f
f.close()
def get_batch(batch_size): def get_batch(batch_size):
vocab = Vocab() vocab = Vocab()
docs = [] docs = []

View File

@ -1454,9 +1454,10 @@ def is_cython_func(func: Callable) -> bool:
if hasattr(func, attr): # function or class instance if hasattr(func, attr): # function or class instance
return True return True
# https://stackoverflow.com/a/55767059 # https://stackoverflow.com/a/55767059
if hasattr(func, "__qualname__") and hasattr(func, "__module__"): # method if hasattr(func, "__qualname__") and hasattr(func, "__module__") \
cls_func = vars(sys.modules[func.__module__])[func.__qualname__.split(".")[0]] and func.__module__ in sys.modules: # method
return hasattr(cls_func, attr) cls_func = vars(sys.modules[func.__module__])[func.__qualname__.split(".")[0]]
return hasattr(cls_func, attr)
return False return False