Fix spancat tests on GPU (#8872)

* Fix spancat tests on GPU

* Fix more spancat tests
This commit is contained in:
Adriane Boyd 2021-08-04 14:29:43 +02:00 committed by GitHub
parent 77d698dcae
commit fa2e7a4bbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,11 @@
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_array_equal
from thinc.api import get_current_ops
from spacy.language import Language
from spacy.training import Example
from spacy.util import fix_random_seed, registry
OPS = get_current_ops()
SPAN_KEY = "labeled_spans"
@ -116,12 +118,12 @@ def test_ngram_suggester(en_tokenizer):
for span in spans:
assert 0 <= span[0] < len(doc)
assert 0 < span[1] <= len(doc)
spans_set.add((span[0], span[1]))
spans_set.add((int(span[0]), int(span[1])))
# spans are unique
assert spans.shape[0] == len(spans_set)
offset += ngrams.lengths[i]
# the number of spans is correct
assert_equal(ngrams.lengths, [max(0, len(doc) - (size - 1)) for doc in docs])
assert_array_equal(OPS.to_numpy(ngrams.lengths), [max(0, len(doc) - (size - 1)) for doc in docs])
# test 1-3-gram suggestions
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2, 3])
@ -129,9 +131,9 @@ def test_ngram_suggester(en_tokenizer):
en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"]
]
ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [1, 3, 6, 9, 12])
assert_equal(
ngrams.data,
assert_array_equal(OPS.to_numpy(ngrams.lengths), [1, 3, 6, 9, 12])
assert_array_equal(
OPS.to_numpy(ngrams.data),
[
# doc 0
[0, 1],
@ -176,13 +178,13 @@ def test_ngram_suggester(en_tokenizer):
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
docs = [en_tokenizer(text) for text in ["", "a", ""]]
ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
# test all empty docs
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
docs = [en_tokenizer(text) for text in ["", "", ""]]
ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
def test_ngram_sizes(en_tokenizer):
@ -195,12 +197,12 @@ def test_ngram_sizes(en_tokenizer):
]
ngrams_1 = size_suggester(docs)
ngrams_2 = range_suggester(docs)
assert_equal(ngrams_1.lengths, [1, 3, 6, 9, 12])
assert_equal(ngrams_1.lengths, ngrams_2.lengths)
assert_equal(ngrams_1.data, ngrams_2.data)
assert_array_equal(OPS.to_numpy(ngrams_1.lengths), [1, 3, 6, 9, 12])
assert_array_equal(OPS.to_numpy(ngrams_1.lengths), OPS.to_numpy(ngrams_2.lengths))
assert_array_equal(OPS.to_numpy(ngrams_1.data), OPS.to_numpy(ngrams_2.data))
# one more variation
suggester_factory = registry.misc.get("spacy.ngram_range_suggester.v1")
range_suggester = suggester_factory(min_size=2, max_size=4)
ngrams_3 = range_suggester(docs)
assert_equal(ngrams_3.lengths, [0, 1, 3, 6, 9])
assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])