This commit is contained in:
svlandeg 2020-10-13 18:52:37 +02:00
parent 6ccacff54e
commit ff83bfae3f

View File

@ -39,7 +39,7 @@ def get_docs():
def test_components_batching_list(name): def test_components_batching_list(name):
nlp = English() nlp = English()
proc = nlp.create_pipe(name) proc = nlp.create_pipe(name)
util_batch_unbatch_List(proc.model, get_docs(), list_floats) util_batch_unbatch_docs_list(proc.model, get_docs(), list_floats)
# Test components with a model of type Model[List[Doc], Floats2d] # Test components with a model of type Model[List[Doc], Floats2d]
@ -48,7 +48,7 @@ def test_components_batching_array(name):
nlp = English() nlp = English()
in_data = [nlp(text) for text in texts] in_data = [nlp(text) for text in texts]
proc = nlp.create_pipe(name) proc = nlp.create_pipe(name)
util_batch_unbatch_Array(proc.model, get_docs(), array) util_batch_unbatch_docs_array(proc.model, get_docs(), array)
LAYERS = [ LAYERS = [
@ -63,19 +63,19 @@ def test_layers_batching_all(model, in_data, out_data):
# In = List[Doc] # In = List[Doc]
if isinstance(in_data, list) and isinstance(in_data[0], Doc): if isinstance(in_data, list) and isinstance(in_data[0], Doc):
if isinstance(out_data, OPS.xp.ndarray) and out_data.ndim == 2: if isinstance(out_data, OPS.xp.ndarray) and out_data.ndim == 2:
util_batch_unbatch_Array(model, in_data, out_data) util_batch_unbatch_docs_array(model, in_data, out_data)
elif ( elif (
isinstance(out_data, list) isinstance(out_data, list)
and isinstance(out_data[0], OPS.xp.ndarray) and isinstance(out_data[0], OPS.xp.ndarray)
and out_data[0].ndim == 2 and out_data[0].ndim == 2
): ):
util_batch_unbatch_List(model, in_data, out_data) util_batch_unbatch_docs_list(model, in_data, out_data)
elif isinstance(out_data, Ragged): elif isinstance(out_data, Ragged):
util_batch_unbatch_Ragged(model, in_data, out_data) util_batch_unbatch_docs_ragged(model, in_data, out_data)
def util_batch_unbatch_List( def util_batch_unbatch_docs_list(
model: Model[List[Doc], List[Array2d]], in_data: List[Doc], out_data: List[Array2d] model: Model[List[Doc], List[Array2d]], in_data: List[Doc], out_data: List[Array2d]
): ):
with data_validation(True): with data_validation(True):
@ -86,7 +86,7 @@ def util_batch_unbatch_List(
assert_almost_equal(Y_batched[i], Y_not_batched[i], decimal=4) assert_almost_equal(Y_batched[i], Y_not_batched[i], decimal=4)
def util_batch_unbatch_Array( def util_batch_unbatch_docs_array(
model: Model[List[Doc], Array2d], in_data: List[Doc], out_data: Array2d model: Model[List[Doc], Array2d], in_data: List[Doc], out_data: Array2d
): ):
with data_validation(True): with data_validation(True):
@ -96,7 +96,7 @@ def util_batch_unbatch_Array(
assert_almost_equal(Y_batched, Y_not_batched, decimal=4) assert_almost_equal(Y_batched, Y_not_batched, decimal=4)
def util_batch_unbatch_Ragged( def util_batch_unbatch_docs_ragged(
model: Model[List[Doc], Ragged], in_data: List[Doc], out_data: Ragged model: Model[List[Doc], Ragged], in_data: List[Doc], out_data: Ragged
): ):
with data_validation(True): with data_validation(True):