diff --git a/spacy/compat.py b/spacy/compat.py index 35f6ecef3..2a551a831 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -16,6 +16,17 @@ try: except ImportError: import copyreg as copy_reg +try: + from cupy.cuda.stream import Stream as CudaStream +except ImportError: + CudaStream = None + +try: + import cupy +except ImportError: + cupy = None + + pickle = pickle copy_reg = copy_reg CudaStream = CudaStream diff --git a/spacy/util.py b/spacy/util.py index 717e4f160..0fbe9f92d 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -10,20 +10,12 @@ import sys import textwrap from .symbols import ORTH -from .compat import path2str, basestring_, input_, unicode_ +from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ LANGUAGES = {} _data_path = Path(__file__).parent / 'data' -try: - from cupy.cuda.stream import Stream as CudaStream -except ImportError: - CudaStream = None -try: - import cupy -except ImportError: - cupy = None def get_lang_class(lang): """Import and load a Language class.