mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
add tests for individual spacy layers
This commit is contained in:
parent
c23041ae60
commit
6ccacff54e
|
@ -1,46 +1,108 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_almost_equal
|
||||||
|
from spacy.vocab import Vocab
|
||||||
from thinc.api import get_current_ops, Model, data_validation
|
from thinc.api import get_current_ops, Model, data_validation
|
||||||
from thinc.types import Array2d
|
from thinc.types import Array2d, Ragged
|
||||||
|
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
from spacy.ml import FeatureExtractor, StaticVectors
|
||||||
|
from spacy.ml._character_embed import CharacterEmbed
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
OPS = get_current_ops()
|
OPS = get_current_ops()
|
||||||
|
|
||||||
texts = ["These are 4 words", "These just three"]
|
texts = ["These are 4 words", "Here just three"]
|
||||||
l0 = [[1, 2], [3, 4], [5, 6], [7, 8]]
|
l0 = [[1, 2], [3, 4], [5, 6], [7, 8]]
|
||||||
l1 = [[9, 8], [7, 6], [5, 4]]
|
l1 = [[9, 8], [7, 6], [5, 4]]
|
||||||
out_list = [OPS.xp.asarray(l0, dtype="f"), OPS.xp.asarray(l1, dtype="f")]
|
list_floats = [OPS.xp.asarray(l0, dtype="f"), OPS.xp.asarray(l1, dtype="f")]
|
||||||
a1 = OPS.xp.asarray(l1, dtype="f")
|
list_ints = [OPS.xp.asarray(l0, dtype="i"), OPS.xp.asarray(l1, dtype="i")]
|
||||||
|
array = OPS.xp.asarray(l1, dtype="f")
|
||||||
|
ragged = Ragged(array, OPS.xp.asarray([2, 1], dtype="i"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_docs():
|
||||||
|
vocab = Vocab()
|
||||||
|
for t in texts:
|
||||||
|
for word in t.split():
|
||||||
|
hash_id = vocab.strings.add(word)
|
||||||
|
vector = numpy.random.uniform(-1, 1, (7,))
|
||||||
|
vocab.set_vector(hash_id, vector)
|
||||||
|
docs = [English(vocab)(t) for t in texts]
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
# Test components with a model of type Model[List[Doc], List[Floats2d]]
|
# Test components with a model of type Model[List[Doc], List[Floats2d]]
|
||||||
@pytest.mark.parametrize("name", ["tagger", "tok2vec", "morphologizer", "senter"])
|
@pytest.mark.parametrize("name", ["tagger", "tok2vec", "morphologizer", "senter"])
|
||||||
def test_layers_batching_all_list(name):
|
def test_components_batching_list(name):
|
||||||
nlp = English()
|
nlp = English()
|
||||||
in_data = [nlp(text) for text in texts]
|
|
||||||
proc = nlp.create_pipe(name)
|
proc = nlp.create_pipe(name)
|
||||||
util_batch_unbatch_List(proc.model, in_data, out_list)
|
util_batch_unbatch_List(proc.model, get_docs(), list_floats)
|
||||||
|
|
||||||
def util_batch_unbatch_List(model: Model[List[Doc], List[Array2d]], in_data: List[Doc], out_data: List[Array2d]):
|
|
||||||
with data_validation(True):
|
|
||||||
model.initialize(in_data, out_data)
|
|
||||||
Y_batched = model.predict(in_data)
|
|
||||||
Y_not_batched = [model.predict([u])[0] for u in in_data]
|
|
||||||
assert_equal(Y_batched, Y_not_batched)
|
|
||||||
|
|
||||||
# Test components with a model of type Model[List[Doc], Floats2d]
|
# Test components with a model of type Model[List[Doc], Floats2d]
|
||||||
@pytest.mark.parametrize("name", ["textcat"])
|
@pytest.mark.parametrize("name", ["textcat"])
|
||||||
def test_layers_batching_all_array(name):
|
def test_components_batching_array(name):
|
||||||
nlp = English()
|
nlp = English()
|
||||||
in_data = [nlp(text) for text in texts]
|
in_data = [nlp(text) for text in texts]
|
||||||
proc = nlp.create_pipe(name)
|
proc = nlp.create_pipe(name)
|
||||||
util_batch_unbatch_Array(proc.model, in_data, a1)
|
util_batch_unbatch_Array(proc.model, get_docs(), array)
|
||||||
|
|
||||||
def util_batch_unbatch_Array(model: Model[List[Doc], Array2d], in_data: List[Doc], out_data: Array2d):
|
|
||||||
|
LAYERS = [
|
||||||
|
(CharacterEmbed(nM=5, nC=3), get_docs(), list_floats),
|
||||||
|
(FeatureExtractor([100, 200]), get_docs(), list_ints),
|
||||||
|
(StaticVectors(), get_docs(), ragged),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model,in_data,out_data", LAYERS)
|
||||||
|
def test_layers_batching_all(model, in_data, out_data):
|
||||||
|
# In = List[Doc]
|
||||||
|
if isinstance(in_data, list) and isinstance(in_data[0], Doc):
|
||||||
|
if isinstance(out_data, OPS.xp.ndarray) and out_data.ndim == 2:
|
||||||
|
util_batch_unbatch_Array(model, in_data, out_data)
|
||||||
|
elif (
|
||||||
|
isinstance(out_data, list)
|
||||||
|
and isinstance(out_data[0], OPS.xp.ndarray)
|
||||||
|
and out_data[0].ndim == 2
|
||||||
|
):
|
||||||
|
util_batch_unbatch_List(model, in_data, out_data)
|
||||||
|
elif isinstance(out_data, Ragged):
|
||||||
|
util_batch_unbatch_Ragged(model, in_data, out_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def util_batch_unbatch_List(
|
||||||
|
model: Model[List[Doc], List[Array2d]], in_data: List[Doc], out_data: List[Array2d]
|
||||||
|
):
|
||||||
with data_validation(True):
|
with data_validation(True):
|
||||||
model.initialize(in_data, out_data)
|
model.initialize(in_data, out_data)
|
||||||
Y_batched = model.predict(in_data)
|
Y_batched = model.predict(in_data)
|
||||||
Y_not_batched = [model.predict([u])[0] for u in in_data]
|
Y_not_batched = [model.predict([u])[0] for u in in_data]
|
||||||
assert_equal(Y_batched, Y_not_batched)
|
for i in range(len(Y_batched)):
|
||||||
|
assert_almost_equal(Y_batched[i], Y_not_batched[i], decimal=4)
|
||||||
|
|
||||||
|
|
||||||
|
def util_batch_unbatch_Array(
|
||||||
|
model: Model[List[Doc], Array2d], in_data: List[Doc], out_data: Array2d
|
||||||
|
):
|
||||||
|
with data_validation(True):
|
||||||
|
model.initialize(in_data, out_data)
|
||||||
|
Y_batched = model.predict(in_data).tolist()
|
||||||
|
Y_not_batched = [model.predict([u])[0] for u in in_data]
|
||||||
|
assert_almost_equal(Y_batched, Y_not_batched, decimal=4)
|
||||||
|
|
||||||
|
|
||||||
|
def util_batch_unbatch_Ragged(
|
||||||
|
model: Model[List[Doc], Ragged], in_data: List[Doc], out_data: Ragged
|
||||||
|
):
|
||||||
|
with data_validation(True):
|
||||||
|
model.initialize(in_data, out_data)
|
||||||
|
Y_batched = model.predict(in_data)
|
||||||
|
Y_not_batched = []
|
||||||
|
for u in in_data:
|
||||||
|
Y_not_batched.extend(model.predict([u]).data.tolist())
|
||||||
|
assert_almost_equal(Y_batched.data, Y_not_batched, decimal=4)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user