Merge pull request #7237 from adrianeboyd/bugfix/is-cython-func-7224

This commit is contained in:
Ines Montani 2021-03-03 00:05:16 +11:00 committed by GitHub
commit 635ae55b74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 5 deletions

View File

@ -7,7 +7,7 @@ from spacy import util
from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList
from spacy.util import dot_to_object, SimpleFrozenList, import_file
from thinc.api import Config, Optimizer, ConfigValidationError
from spacy.training.batchers import minibatch_by_words
from spacy.lang.en import English
@ -17,7 +17,7 @@ from spacy.schemas import ConfigSchemaTraining
from thinc.api import get_current_ops, NumpyOps, CupyOps
from .util import get_random_doc
from .util import get_random_doc, make_tempdir
@pytest.fixture
@ -347,3 +347,35 @@ def test_resolve_dot_names():
errors = e.value.errors
assert len(errors) == 1
assert errors[0]["loc"] == ["training", "xyz"]
def test_import_code():
code_str = """
from spacy import Language
class DummyComponent:
def __init__(self, vocab, name):
pass
def initialize(self, get_examples, *, nlp, dummy_param: int):
pass
@Language.factory(
"dummy_component",
)
def make_dummy_component(
nlp: Language, name: str
):
return DummyComponent(nlp.vocab, name)
"""
with make_tempdir() as temp_dir:
code_path = os.path.join(temp_dir, "code.py")
with open(code_path, "w") as fileh:
fileh.write(code_str)
import_file("python_code", code_path)
config = {"initialize": {"components": {"dummy_component": {"dummy_param": 1}}}}
nlp = English.from_config(config)
nlp.add_pipe("dummy_component")
nlp.initialize()

View File

@ -1454,7 +1454,8 @@ def is_cython_func(func: Callable) -> bool:
if hasattr(func, attr): # function or class instance
return True
# https://stackoverflow.com/a/55767059
if hasattr(func, "__qualname__") and hasattr(func, "__module__"): # method
if hasattr(func, "__qualname__") and hasattr(func, "__module__") \
and func.__module__ in sys.modules: # method
cls_func = vars(sys.modules[func.__module__])[func.__qualname__.split(".")[0]]
return hasattr(cls_func, attr)
return False