Set up GPU CI testing (#7293)

* Set up CI for tests with GPU agent

* Update tests for enabled GPU

* Fix steps filename

* Add parallel build jobs as a setting

* Fix test requirements

* Fix install test requirements condition

* Fix pipeline models test

* Reset current ops in prefer/require testing

* Fix more tests

* Remove separate test_models test

* Fix regression 5551

* Fix StaticVectors for GPU use

* Fix vocab tests

* Fix regression test 5082

* Move azure steps to .github and reenable default pool jobs

* Consolidate/rename azure steps

Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
Adriane Boyd, 2021-04-22 14:58:29 +02:00 (committed via GitHub)
commit 36ecba224e, parent bdb485cc80
16 changed files with 238 additions and 137 deletions

.github/azure-steps.yml (new file, +57 lines)

@@ -0,0 +1,57 @@
parameters:
  python_version: ''
  architecture: ''
  prefix: ''
  gpu: false
  num_build_jobs: 1

steps:
  - task: UsePythonVersion@0
    inputs:
      versionSpec: ${{ parameters.python_version }}
      architecture: ${{ parameters.architecture }}

  - script: |
      ${{ parameters.prefix }} python -m pip install -U pip setuptools
      ${{ parameters.prefix }} python -m pip install -U -r requirements.txt
    displayName: "Install dependencies"

  - script: |
      ${{ parameters.prefix }} python setup.py build_ext --inplace -j ${{ parameters.num_build_jobs }}
      ${{ parameters.prefix }} python setup.py sdist --formats=gztar
    displayName: "Compile and build sdist"

  - task: DeleteFiles@1
    inputs:
      contents: "spacy"
    displayName: "Delete source directory"

  - script: |
      ${{ parameters.prefix }} python -m pip freeze --exclude torch --exclude cupy-cuda110 > installed.txt
      ${{ parameters.prefix }} python -m pip uninstall -y -r installed.txt
    displayName: "Uninstall all packages"

  - bash: |
      ${{ parameters.prefix }} SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
      ${{ parameters.prefix }} python -m pip install dist/$SDIST
    displayName: "Install from sdist"

  - script: |
      ${{ parameters.prefix }} python -m pip install -U -r requirements.txt
    displayName: "Install test requirements"

  - script: |
      ${{ parameters.prefix }} python -m pip install -U cupy-cuda110
      ${{ parameters.prefix }} python -m pip install "torch==1.7.1+cu110" -f https://download.pytorch.org/whl/torch_stable.html
    displayName: "Install GPU requirements"
    condition: eq(${{ parameters.gpu }}, true)

  - script: |
      ${{ parameters.prefix }} python -m pytest --pyargs spacy
    displayName: "Run CPU tests"
    condition: eq(${{ parameters.gpu }}, false)

  - script: |
      ${{ parameters.prefix }} python -m pytest --pyargs spacy -p spacy.tests.enable_gpu
    displayName: "Run GPU tests"
    condition: eq(${{ parameters.gpu }}, true)
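The GPU-only steps pin cupy-cuda110 and a +cu110 torch wheel, so the self-hosted agent is assumed to provide CUDA 11.0. A quick local sanity check in the same spirit (my sketch, not part of the commit; assumes cupy and torch are already installed):

# Sketch: verify the environment matches what the GPU job installs.
import cupy
import torch

assert torch.cuda.is_available(), "no CUDA device visible to torch"
print("CUDA devices:", cupy.cuda.runtime.getDeviceCount())
print("torch:", torch.__version__)  # expected to end in +cu110 in CI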

@@ -76,39 +76,24 @@ jobs:
       maxParallel: 4
     pool:
       vmImage: $(imageName)
     steps:
-      - task: UsePythonVersion@0
-        inputs:
-          versionSpec: "$(python.version)"
-          architecture: "x64"
-      - script: |
-          python -m pip install -U setuptools
-          pip install -r requirements.txt
-        displayName: "Install dependencies"
-      - script: |
-          python setup.py build_ext --inplace
-          python setup.py sdist --formats=gztar
-        displayName: "Compile and build sdist"
-      - task: DeleteFiles@1
-        inputs:
-          contents: "spacy"
-        displayName: "Delete source directory"
-      - script: |
-          pip freeze > installed.txt
-          pip uninstall -y -r installed.txt
-        displayName: "Uninstall all packages"
-      - bash: |
-          SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
-          pip install dist/$SDIST
-        displayName: "Install from sdist"
-      - script: |
-          pip install -r requirements.txt
-          python -m pytest --pyargs spacy
-        displayName: "Run tests"
+      - template: .github/azure-steps.yml
+        parameters:
+          python_version: '$(python.version)'
+          architecture: 'x64'
+
+  - job: "TestGPU"
+    dependsOn: "Validate"
+    strategy:
+      matrix:
+        Python38LinuxX64_GPU:
+          python.version: '3.8'
+    pool:
+      name: "LinuxX64_GPU"
+    steps:
+      - template: .github/azure-steps.yml
+        parameters:
+          python_version: '$(python.version)'
+          architecture: 'x64'
+          gpu: true
+          num_build_jobs: 24

@@ -38,7 +38,7 @@ def forward(
         return _handle_empty(model.ops, model.get_dim("nO"))
     key_attr = model.attrs["key_attr"]
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    V = cast(Floats2d, docs[0].vocab.vectors.data)
+    V = cast(Floats2d, model.ops.asarray(docs[0].vocab.vectors.data))
     rows = model.ops.flatten(
         [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
    )
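This one-line change is the "Fix StaticVectors for GPU use" item: the vocab's vector table may still be a numpy array while the model's ops are CupyOps, so it is copied onto the model's device before use. The pattern in isolation (a sketch with made-up shapes):

# Sketch: ops.asarray copies an array onto the ops back end's device
# (a no-op under NumpyOps, a host-to-GPU transfer under CupyOps).
import numpy
from thinc.api import get_current_ops

ops = get_current_ops()
vectors = numpy.zeros((10, 4), dtype="f")  # e.g. a vocab's vector table
V = ops.asarray(vectors)                   # now usable in ops.gemm etc.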

@@ -0,0 +1,3 @@
from spacy import require_gpu

require_gpu()
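This three-line module backs the CI's `-p spacy.tests.enable_gpu` switch: pytest imports a `-p` module as a plugin before collecting tests, so `require_gpu()` runs once and flips thinc's current ops to the GPU for the whole session. A hedged check that the switch took effect (my sketch, not part of the commit; needs cupy and a CUDA device):

# Run with: python -m pytest --pyargs spacy -p spacy.tests.enable_gpu
from thinc.api import CupyOps, get_current_ops

def test_gpu_ops_active():
    # Only passes in a session where the enable_gpu plugin was loaded.
    assert isinstance(get_current_ops(), CupyOps)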

@@ -5,6 +5,7 @@ from spacy.tokens import Span
 from spacy.language import Language
 from spacy.pipeline import EntityRuler
 from spacy.errors import MatchPatternError
+from thinc.api import NumpyOps, get_current_ops


 @pytest.fixture
@@ -201,13 +202,14 @@ def test_entity_ruler_overlapping_spans(nlp):
 @pytest.mark.parametrize("n_process", [1, 2])
 def test_entity_ruler_multiprocessing(nlp, n_process):
-    texts = ["I enjoy eating Pizza Hut pizza."]
-    patterns = [{"label": "FASTFOOD", "pattern": "Pizza Hut", "id": "1234"}]
-    ruler = nlp.add_pipe("entity_ruler")
-    ruler.add_patterns(patterns)
-    for doc in nlp.pipe(texts, n_process=2):
-        for ent in doc.ents:
-            assert ent.ent_id_ == "1234"
+    if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
+        texts = ["I enjoy eating Pizza Hut pizza."]
+        patterns = [{"label": "FASTFOOD", "pattern": "Pizza Hut", "id": "1234"}]
+        ruler = nlp.add_pipe("entity_ruler")
+        ruler.add_patterns(patterns)
+        for doc in nlp.pipe(texts, n_process=2):
+            for ent in doc.ents:
+                assert ent.ent_id_ == "1234"
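The new guard skips the multiprocessing path on GPU ops, presumably because forked worker processes cannot reuse the parent's CUDA context. The same gating could be expressed as a skip marker; a sketch with a hypothetical helper name, not what this commit does:

import pytest
from thinc.api import NumpyOps, get_current_ops

# Hypothetical marker: skip n_process > 1 cases when not on CPU ops.
# Note the condition is evaluated once, at module import time.
requires_cpu_ops = pytest.mark.skipif(
    not isinstance(get_current_ops(), NumpyOps),
    reason="n_process > 1 is only exercised on CPU ops",
)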

@@ -4,7 +4,7 @@ import numpy
 import pytest
 from numpy.testing import assert_almost_equal
 from spacy.vocab import Vocab
-from thinc.api import NumpyOps, Model, data_validation
+from thinc.api import Model, data_validation, get_current_ops
 from thinc.types import Array2d, Ragged
 from spacy.lang.en import English
@@ -13,7 +13,7 @@ from spacy.ml._character_embed import CharacterEmbed
 from spacy.tokens import Doc

-OPS = NumpyOps()
+OPS = get_current_ops()

 texts = ["These are 4 words", "Here just three"]
 l0 = [[1, 2], [3, 4], [5, 6], [7, 8]]
@@ -82,7 +82,7 @@ def util_batch_unbatch_docs_list(
     Y_batched = model.predict(in_data)
     Y_not_batched = [model.predict([u])[0] for u in in_data]
     for i in range(len(Y_batched)):
-        assert_almost_equal(Y_batched[i], Y_not_batched[i], decimal=4)
+        assert_almost_equal(OPS.to_numpy(Y_batched[i]), OPS.to_numpy(Y_not_batched[i]), decimal=4)


 def util_batch_unbatch_docs_array(
@@ -91,7 +91,7 @@ def util_batch_unbatch_docs_array(
     with data_validation(True):
         model.initialize(in_data, out_data)
         Y_batched = model.predict(in_data).tolist()
-        Y_not_batched = [model.predict([u])[0] for u in in_data]
+        Y_not_batched = [model.predict([u])[0].tolist() for u in in_data]
         assert_almost_equal(Y_batched, Y_not_batched, decimal=4)
@@ -100,8 +100,8 @@ def util_batch_unbatch_docs_ragged(
 ):
     with data_validation(True):
         model.initialize(in_data, out_data)
-        Y_batched = model.predict(in_data)
+        Y_batched = model.predict(in_data).data.tolist()
         Y_not_batched = []
         for u in in_data:
             Y_not_batched.extend(model.predict([u]).data.tolist())
-        assert_almost_equal(Y_batched.data, Y_not_batched, decimal=4)
+        assert_almost_equal(Y_batched, Y_not_batched, decimal=4)
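The recurring idiom in this file: `numpy.testing` helpers cannot compare cupy arrays, so both sides are brought to host memory with `ops.to_numpy` (or flattened via `.tolist()`) before asserting. In isolation (a sketch):

from numpy.testing import assert_almost_equal
from thinc.api import get_current_ops

ops = get_current_ops()
a = ops.asarray([1.0, 2.0, 3.0])  # numpy or cupy, depending on ops
b = ops.asarray([1.0, 2.0, 3.0])
# to_numpy is a cheap copy on CPU and a device-to-host transfer on GPU
assert_almost_equal(ops.to_numpy(a), ops.to_numpy(b), decimal=4)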

@@ -1,7 +1,7 @@
 import pytest
 import random
 import numpy.random
-from numpy.testing import assert_equal
+from numpy.testing import assert_almost_equal
 from thinc.api import fix_random_seed
 from spacy import util
 from spacy.lang.en import English
@@ -222,8 +222,12 @@ def test_overfitting_IO():
     batch_cats_1 = [doc.cats for doc in nlp.pipe(texts)]
     batch_cats_2 = [doc.cats for doc in nlp.pipe(texts)]
     no_batch_cats = [doc.cats for doc in [nlp(text) for text in texts]]
-    assert_equal(batch_cats_1, batch_cats_2)
-    assert_equal(batch_cats_1, no_batch_cats)
+    for cats_1, cats_2 in zip(batch_cats_1, batch_cats_2):
+        for cat in cats_1:
+            assert_almost_equal(cats_1[cat], cats_2[cat], decimal=5)
+    for cats_1, cats_2 in zip(batch_cats_1, no_batch_cats):
+        for cat in cats_1:
+            assert_almost_equal(cats_1[cat], cats_2[cat], decimal=5)


 def test_overfitting_IO_multi():
@@ -270,8 +274,12 @@ def test_overfitting_IO_multi():
     batch_deps_1 = [doc.cats for doc in nlp.pipe(texts)]
     batch_deps_2 = [doc.cats for doc in nlp.pipe(texts)]
     no_batch_deps = [doc.cats for doc in [nlp(text) for text in texts]]
-    assert_equal(batch_deps_1, batch_deps_2)
-    assert_equal(batch_deps_1, no_batch_deps)
+    for cats_1, cats_2 in zip(batch_deps_1, batch_deps_2):
+        for cat in cats_1:
+            assert_almost_equal(cats_1[cat], cats_2[cat], decimal=5)
+    for cats_1, cats_2 in zip(batch_deps_1, no_batch_deps):
+        for cat in cats_1:
+            assert_almost_equal(cats_1[cat], cats_2[cat], decimal=5)

 # fmt: off
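Exact equality is relaxed to `decimal=5` because batched and unbatched prediction can take different kernel paths on GPU, leaving tiny float differences in `doc.cats`. The comparison in miniature, with illustrative numbers:

from numpy.testing import assert_almost_equal

# Illustrative scores: same label, slightly different low decimals, as
# can happen between batched and unbatched prediction on GPU.
cats_1 = {"POSITIVE": 0.9123456, "NEGATIVE": 0.0876544}
cats_2 = {"POSITIVE": 0.9123451, "NEGATIVE": 0.0876549}
for cat in cats_1:
    assert_almost_equal(cats_1[cat], cats_2[cat], decimal=5)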

@@ -8,8 +8,8 @@ from spacy.tokens import Doc
 from spacy.training import Example
 from spacy import util
 from spacy.lang.en import English
-from thinc.api import Config
-from numpy.testing import assert_equal
+from thinc.api import Config, get_current_ops
+from numpy.testing import assert_array_equal

 from ..util import get_batch, make_tempdir
@@ -160,7 +160,8 @@ def test_tok2vec_listener():
     doc = nlp("Running the pipeline as a whole.")
     doc_tensor = tagger_tok2vec.predict([doc])[0]
-    assert_equal(doc.tensor, doc_tensor)
+    ops = get_current_ops()
+    assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))

     # TODO: should this warn or error?
     nlp.select_pipes(disable="tok2vec")

@@ -9,6 +9,7 @@ from spacy.language import Language
 from spacy.util import ensure_path, load_model_from_path
 import numpy
 import pickle
+from thinc.api import NumpyOps, get_current_ops

 from ..util import make_tempdir
@@ -169,21 +170,22 @@ def test_issue4725_1():
 def test_issue4725_2():
-    # ensures that this runs correctly and doesn't hang or crash because of the global vectors
-    # if it does crash, it's usually because of calling 'spawn' for multiprocessing (e.g. on Windows),
-    # or because of issues with pickling the NER (cf test_issue4725_1)
-    vocab = Vocab(vectors_name="test_vocab_add_vector")
-    data = numpy.ndarray((5, 3), dtype="f")
-    data[0] = 1.0
-    data[1] = 2.0
-    vocab.set_vector("cat", data[0])
-    vocab.set_vector("dog", data[1])
-    nlp = English(vocab=vocab)
-    nlp.add_pipe("ner")
-    nlp.initialize()
-    docs = ["Kurt is in London."] * 10
-    for _ in nlp.pipe(docs, batch_size=2, n_process=2):
-        pass
+    if isinstance(get_current_ops(), NumpyOps):
+        # ensures that this runs correctly and doesn't hang or crash because of the global vectors
+        # if it does crash, it's usually because of calling 'spawn' for multiprocessing (e.g. on Windows),
+        # or because of issues with pickling the NER (cf test_issue4725_1)
+        vocab = Vocab(vectors_name="test_vocab_add_vector")
+        data = numpy.ndarray((5, 3), dtype="f")
+        data[0] = 1.0
+        data[1] = 2.0
+        vocab.set_vector("cat", data[0])
+        vocab.set_vector("dog", data[1])
+        nlp = English(vocab=vocab)
+        nlp.add_pipe("ner")
+        nlp.initialize()
+        docs = ["Kurt is in London."] * 10
+        for _ in nlp.pipe(docs, batch_size=2, n_process=2):
+            pass


 def test_issue4849():
@@ -204,10 +206,11 @@ def test_issue4849():
         count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
     assert count_ents == 2
     # USING 2 PROCESSES
-    count_ents = 0
-    for doc in nlp.pipe([text], n_process=2):
-        count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
-    assert count_ents == 2
+    if isinstance(get_current_ops(), NumpyOps):
+        count_ents = 0
+        for doc in nlp.pipe([text], n_process=2):
+            count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
+        assert count_ents == 2


 @Language.factory("my_pipe")
@@ -239,10 +242,11 @@ def test_issue4903():
     nlp.add_pipe("sentencizer")
     nlp.add_pipe("my_pipe", after="sentencizer")
     text = ["I like bananas.", "Do you like them?", "No, I prefer wasabi."]
-    docs = list(nlp.pipe(text, n_process=2))
-    assert docs[0].text == "I like bananas."
-    assert docs[1].text == "Do you like them?"
-    assert docs[2].text == "No, I prefer wasabi."
+    if isinstance(get_current_ops(), NumpyOps):
+        docs = list(nlp.pipe(text, n_process=2))
+        assert docs[0].text == "I like bananas."
+        assert docs[1].text == "Do you like them?"
+        assert docs[2].text == "No, I prefer wasabi."


 def test_issue4924():

@@ -6,6 +6,7 @@ from spacy.language import Language
 from spacy.lang.en.syntax_iterators import noun_chunks
 from spacy.vocab import Vocab
 import spacy
+from thinc.api import get_current_ops
 import pytest

 from ...util import make_tempdir
@@ -54,16 +55,17 @@ def test_issue5082():
     ruler.add_patterns(patterns)
     parsed_vectors_1 = [t.vector for t in nlp(text)]
     assert len(parsed_vectors_1) == 4
-    numpy.testing.assert_array_equal(parsed_vectors_1[0], array1)
-    numpy.testing.assert_array_equal(parsed_vectors_1[1], array2)
-    numpy.testing.assert_array_equal(parsed_vectors_1[2], array3)
-    numpy.testing.assert_array_equal(parsed_vectors_1[3], array4)
+    ops = get_current_ops()
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[0]), array1)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[1]), array2)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[2]), array3)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[3]), array4)
     nlp.add_pipe("merge_entities")
     parsed_vectors_2 = [t.vector for t in nlp(text)]
     assert len(parsed_vectors_2) == 3
-    numpy.testing.assert_array_equal(parsed_vectors_2[0], array1)
-    numpy.testing.assert_array_equal(parsed_vectors_2[1], array2)
-    numpy.testing.assert_array_equal(parsed_vectors_2[2], array34)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[0]), array1)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[1]), array2)
+    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[2]), array34)


 def test_issue5137():

@@ -1,5 +1,6 @@
 import pytest
-from thinc.api import Config, fix_random_seed
+from numpy.testing import assert_almost_equal
+from thinc.api import Config, fix_random_seed, get_current_ops

 from spacy.lang.en import English
 from spacy.pipeline.textcat import single_label_default_config, single_label_bow_config
@@ -44,11 +45,12 @@ def test_issue5551(textcat_config):
         nlp.update([Example.from_dict(doc, annots)])
         # Store the result of each iteration
         result = pipe.model.predict([doc])
-        results.append(list(result[0]))
+        results.append(result[0])
     # All results should be the same because of the fixed seed
     assert len(results) == 3
-    assert results[0] == results[1]
-    assert results[0] == results[2]
+    ops = get_current_ops()
+    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]))
+    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]))


 def test_issue5838():

@@ -10,6 +10,7 @@ from spacy.lang.en import English
 from spacy.lang.de import German
 from spacy.util import registry, ignore_error, raise_error
 import spacy
+from thinc.api import NumpyOps, get_current_ops

 from .util import add_vecs_to_vocab, assert_docs_equal
@@ -142,25 +143,29 @@ def texts():
 @pytest.mark.parametrize("n_process", [1, 2])
 def test_language_pipe(nlp2, n_process, texts):
-    texts = texts * 10
-    expecteds = [nlp2(text) for text in texts]
-    docs = nlp2.pipe(texts, n_process=n_process, batch_size=2)
-    for doc, expected_doc in zip(docs, expecteds):
-        assert_docs_equal(doc, expected_doc)
+    ops = get_current_ops()
+    if isinstance(ops, NumpyOps) or n_process < 2:
+        texts = texts * 10
+        expecteds = [nlp2(text) for text in texts]
+        docs = nlp2.pipe(texts, n_process=n_process, batch_size=2)
+        for doc, expected_doc in zip(docs, expecteds):
+            assert_docs_equal(doc, expected_doc)


 @pytest.mark.parametrize("n_process", [1, 2])
 def test_language_pipe_stream(nlp2, n_process, texts):
-    # check if nlp.pipe can handle infinite length iterator properly.
-    stream_texts = itertools.cycle(texts)
-    texts0, texts1 = itertools.tee(stream_texts)
-    expecteds = (nlp2(text) for text in texts0)
-    docs = nlp2.pipe(texts1, n_process=n_process, batch_size=2)
-    n_fetch = 20
-    for doc, expected_doc in itertools.islice(zip(docs, expecteds), n_fetch):
-        assert_docs_equal(doc, expected_doc)
+    ops = get_current_ops()
+    if isinstance(ops, NumpyOps) or n_process < 2:
+        # check if nlp.pipe can handle infinite length iterator properly.
+        stream_texts = itertools.cycle(texts)
+        texts0, texts1 = itertools.tee(stream_texts)
+        expecteds = (nlp2(text) for text in texts0)
+        docs = nlp2.pipe(texts1, n_process=n_process, batch_size=2)
+        n_fetch = 20
+        for doc, expected_doc in itertools.islice(zip(docs, expecteds), n_fetch):
+            assert_docs_equal(doc, expected_doc)


 def test_language_pipe_error_handler():

@@ -8,7 +8,8 @@ from spacy import prefer_gpu, require_gpu, require_cpu
 from spacy.ml._precomputable_affine import PrecomputableAffine
 from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
 from spacy.util import dot_to_object, SimpleFrozenList, import_file
-from thinc.api import Config, Optimizer, ConfigValidationError
+from thinc.api import Config, Optimizer, ConfigValidationError, get_current_ops
+from thinc.api import set_current_ops
 from spacy.training.batchers import minibatch_by_words
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
@@ -81,6 +82,7 @@ def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):

 def test_prefer_gpu():
+    current_ops = get_current_ops()
     try:
         import cupy  # noqa: F401

@@ -88,9 +90,11 @@ def test_prefer_gpu():
         assert isinstance(get_current_ops(), CupyOps)
     except ImportError:
         assert not prefer_gpu()
+    set_current_ops(current_ops)


 def test_require_gpu():
+    current_ops = get_current_ops()
     try:
         import cupy  # noqa: F401

@@ -99,9 +103,11 @@ def test_require_gpu():
     except ImportError:
         with pytest.raises(ValueError):
             require_gpu()
+    set_current_ops(current_ops)


 def test_require_cpu():
+    current_ops = get_current_ops()
     require_cpu()
     assert isinstance(get_current_ops(), NumpyOps)
     try:
@@ -113,6 +119,7 @@ def test_require_cpu():
         pass
     require_cpu()
     assert isinstance(get_current_ops(), NumpyOps)
+    set_current_ops(current_ops)


 def test_ascii_filenames():
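This is the "Reset current ops in prefer/require testing" item: each ops-switching test now snapshots the current ops and restores them, so a `require_gpu()` inside one test cannot leak into later tests. The same idea as a reusable fixture (a hypothetical helper, not in this commit):

import pytest
from thinc.api import get_current_ops, set_current_ops

@pytest.fixture
def restore_ops():
    # Snapshot the active ops, run the test, then put them back.
    current_ops = get_current_ops()
    yield
    set_current_ops(current_ops)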

@@ -1,7 +1,7 @@
 from typing import List
 import pytest
 from thinc.api import fix_random_seed, Adam, set_dropout_rate
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal
 import numpy
 from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
 from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
@@ -109,7 +109,7 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
     model2.initialize()
     params1 = get_all_params(model1)
     params2 = get_all_params(model2)
-    assert_array_equal(params1, params2)
+    assert_array_equal(model1.ops.to_numpy(params1), model2.ops.to_numpy(params2))


 @pytest.mark.parametrize(
@@ -134,14 +134,25 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
     for i in range(len(tok2vec1)):
         for j in range(len(tok2vec1[i])):
-            assert_array_equal(
-                numpy.asarray(tok2vec1[i][j]), numpy.asarray(tok2vec2[i][j])
-            )
+            assert_array_equal(
+                numpy.asarray(model1.ops.to_numpy(tok2vec1[i][j])),
+                numpy.asarray(model2.ops.to_numpy(tok2vec2[i][j])),
+            )
+    try:
+        Y1 = model1.ops.to_numpy(Y1)
+        Y2 = model2.ops.to_numpy(Y2)
+    except Exception:
+        pass
     if isinstance(Y1, numpy.ndarray):
         assert_array_equal(Y1, Y2)
     elif isinstance(Y1, List):
         assert len(Y1) == len(Y2)
         for y1, y2 in zip(Y1, Y2):
+            try:
+                y1 = model1.ops.to_numpy(y1)
+                y2 = model2.ops.to_numpy(y2)
+            except Exception:
+                pass
             assert_array_equal(y1, y2)
     else:
         raise ValueError(f"Could not compare type {type(Y1)}")
@@ -169,12 +180,17 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
         model.finish_update(optimizer)
         updated_params = get_all_params(model)
         with pytest.raises(AssertionError):
-            assert_array_equal(initial_params, updated_params)
+            assert_array_equal(
+                model.ops.to_numpy(initial_params), model.ops.to_numpy(updated_params)
+            )
         return model

     model1 = get_updated_model()
     model2 = get_updated_model()
-    assert_array_equal(get_all_params(model1), get_all_params(model2))
+    assert_array_almost_equal(
+        model1.ops.to_numpy(get_all_params(model1)),
+        model2.ops.to_numpy(get_all_params(model2)),
+    )


 @pytest.mark.parametrize("model_func,kwargs", [(StaticVectors, {"nO": 128, "nM": 300})])

@@ -5,6 +5,7 @@ import srsly
 from spacy.tokens import Doc
 from spacy.vocab import Vocab
 from spacy.util import make_tempdir  # noqa: F401
+from thinc.api import get_current_ops


 @contextlib.contextmanager
@@ -58,7 +59,10 @@ def add_vecs_to_vocab(vocab, vectors):

 def get_cosine(vec1, vec2):
     """Get cosine for two given vectors"""
-    return numpy.dot(vec1, vec2) / (numpy.linalg.norm(vec1) * numpy.linalg.norm(vec2))
+    OPS = get_current_ops()
+    v1 = OPS.to_numpy(OPS.asarray(vec1))
+    v2 = OPS.to_numpy(OPS.asarray(vec2))
+    return numpy.dot(v1, v2) / (numpy.linalg.norm(v1) * numpy.linalg.norm(v2))


 def assert_docs_equal(doc1, doc2):
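Usage note on the rewritten helper above: since both inputs are routed through the current ops and back to numpy, `get_cosine` now accepts numpy arrays, cupy arrays, or plain lists. A sketch (assumes spaCy's packaged test utilities are importable, as they are when tests run via --pyargs):

import numpy
from spacy.tests.util import get_cosine

v1 = numpy.asarray([1.0, 0.0], dtype="f")
v2 = numpy.asarray([1.0, 1.0], dtype="f")
print(get_cosine(v1, v2))  # ~0.7071; cupy arrays or lists also work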

@@ -1,6 +1,7 @@
 import pytest
 import numpy
 from numpy.testing import assert_allclose, assert_equal
+from thinc.api import get_current_ops
 from spacy.vocab import Vocab
 from spacy.vectors import Vectors
 from spacy.tokenizer import Tokenizer
@@ -9,6 +10,7 @@ from spacy.tokens import Doc

 from ..util import add_vecs_to_vocab, get_cosine, make_tempdir

+OPS = get_current_ops()


 @pytest.fixture
 def strings():
@@ -18,21 +20,21 @@ def strings():
 @pytest.fixture
 def vectors():
     return [
-        ("apple", [1, 2, 3]),
-        ("orange", [-1, -2, -3]),
-        ("and", [-1, -1, -1]),
-        ("juice", [5, 5, 10]),
-        ("pie", [7, 6.3, 8.9]),
+        ("apple", OPS.asarray([1, 2, 3])),
+        ("orange", OPS.asarray([-1, -2, -3])),
+        ("and", OPS.asarray([-1, -1, -1])),
+        ("juice", OPS.asarray([5, 5, 10])),
+        ("pie", OPS.asarray([7, 6.3, 8.9])),
     ]


 @pytest.fixture
 def ngrams_vectors():
     return [
-        ("apple", [1, 2, 3]),
-        ("app", [-0.1, -0.2, -0.3]),
-        ("ppl", [-0.2, -0.3, -0.4]),
-        ("pl", [0.7, 0.8, 0.9]),
+        ("apple", OPS.asarray([1, 2, 3])),
+        ("app", OPS.asarray([-0.1, -0.2, -0.3])),
+        ("ppl", OPS.asarray([-0.2, -0.3, -0.4])),
+        ("pl", OPS.asarray([0.7, 0.8, 0.9])),
     ]
@@ -171,8 +173,10 @@ def test_vectors_most_similar_identical():
 @pytest.mark.parametrize("text", ["apple and orange"])
 def test_vectors_token_vector(tokenizer_v, vectors, text):
     doc = tokenizer_v(text)
-    assert vectors[0] == (doc[0].text, list(doc[0].vector))
-    assert vectors[1] == (doc[2].text, list(doc[2].vector))
+    assert vectors[0][0] == doc[0].text
+    assert all([a == b for a, b in zip(vectors[0][1], doc[0].vector)])
+    assert vectors[1][0] == doc[2].text
+    assert all([a == b for a, b in zip(vectors[1][1], doc[2].vector)])


 @pytest.mark.parametrize("text", ["apple"])
@@ -301,7 +305,7 @@ def test_vectors_doc_doc_similarity(vocab, text1, text2):

 def test_vocab_add_vector():
     vocab = Vocab(vectors_name="test_vocab_add_vector")
-    data = numpy.ndarray((5, 3), dtype="f")
+    data = OPS.xp.ndarray((5, 3), dtype="f")
     data[0] = 1.0
     data[1] = 2.0
     vocab.set_vector("cat", data[0])
@@ -320,10 +324,10 @@ def test_vocab_prune_vectors():
     _ = vocab["cat"]  # noqa: F841
     _ = vocab["dog"]  # noqa: F841
     _ = vocab["kitten"]  # noqa: F841
-    data = numpy.ndarray((5, 3), dtype="f")
-    data[0] = [1.0, 1.2, 1.1]
-    data[1] = [0.3, 1.3, 1.0]
-    data[2] = [0.9, 1.22, 1.05]
+    data = OPS.xp.ndarray((5, 3), dtype="f")
+    data[0] = OPS.asarray([1.0, 1.2, 1.1])
+    data[1] = OPS.asarray([0.3, 1.3, 1.0])
+    data[2] = OPS.asarray([0.9, 1.22, 1.05])
     vocab.set_vector("cat", data[0])
     vocab.set_vector("dog", data[1])
     vocab.set_vector("kitten", data[2])
@@ -332,40 +336,41 @@ def test_vocab_prune_vectors():
     assert list(remap.keys()) == ["kitten"]
     neighbour, similarity = list(remap.values())[0]
     assert neighbour == "cat", remap
-    assert_allclose(similarity, get_cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
+    cosine = get_cosine(data[0], data[2])
+    assert_allclose(float(similarity), cosine, atol=1e-4, rtol=1e-3)


 def test_vectors_serialize():
-    data = numpy.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
+    data = OPS.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
     v = Vectors(data=data, keys=["A", "B", "C"])
     b = v.to_bytes()
     v_r = Vectors()
     v_r.from_bytes(b)
-    assert_equal(v.data, v_r.data)
+    assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
     assert v.key2row == v_r.key2row
     v.resize((5, 4))
     v_r.resize((5, 4))
-    row = v.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
-    row_r = v_r.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
+    row = v.add("D", vector=OPS.asarray([1, 2, 3, 4], dtype="f"))
+    row_r = v_r.add("D", vector=OPS.asarray([1, 2, 3, 4], dtype="f"))
     assert row == row_r
-    assert_equal(v.data, v_r.data)
+    assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
     assert v.is_full == v_r.is_full
     with make_tempdir() as d:
         v.to_disk(d)
         v_r.from_disk(d)
-        assert_equal(v.data, v_r.data)
+        assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
        assert v.key2row == v_r.key2row
         v.resize((5, 4))
         v_r.resize((5, 4))
-        row = v.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
-        row_r = v_r.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
+        row = v.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
+        row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
         assert row == row_r
-        assert_equal(v.data, v_r.data)
+        assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))


 def test_vector_is_oov():
     vocab = Vocab(vectors_name="test_vocab_is_oov")
-    data = numpy.ndarray((5, 3), dtype="f")
+    data = OPS.xp.ndarray((5, 3), dtype="f")
     data[0] = 1.0
     data[1] = 2.0
     vocab.set_vector("cat", data[0])
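A note on the `OPS.xp` pattern used throughout this file: `xp` is the ops back end's array module, numpy under NumpyOps and cupy under CupyOps, so the fixtures allocate on whichever device the test session selected. A minimal sketch:

from thinc.api import get_current_ops

OPS = get_current_ops()
# numpy.zeros on CPU sessions; cupy.zeros once enable_gpu switched ops
data = OPS.xp.zeros((5, 3), dtype="f")
print(type(data))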