Add tests for get_affixes

richardpaulhudson 2022-09-13 12:45:40 +02:00
parent b2074a15d1
commit 910a6bc98f
2 changed files with 39 additions and 0 deletions

@@ -975,3 +975,40 @@ def test_doc_spans_setdefault(en_tokenizer):
    assert len(doc.spans["key2"]) == 1
    doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
    assert len(doc.spans["key3"]) == 2


def test_get_affixes_good_case(en_tokenizer):
    doc = en_tokenizer("spaCy✨ and Prodigy")
    prefixes = doc.get_affixes(False, 1, 5, "", 2, 3)
    suffixes = doc.get_affixes(True, 2, 6, "xx✨rP", 2, 3)
    assert prefixes[0][3, 3, 6] == suffixes[0][3, 3, 6]
    assert prefixes[0][3, 3, 7] == suffixes[0][3, 3, 7]
    assert prefixes[0][3, 3, 4] == suffixes[0][3, 3, 8]
    assert prefixes[0][3, 3, 5] == suffixes[0][3, 3, 9]
    assert (prefixes[0][0, :, 2:] == 0).all()
    assert not (suffixes[0][0, :, 2:] == 0).all()
    assert (suffixes[0][0, :, 4:] == 0).all()
    assert (prefixes[0][1, :, 4:] == 0).all()
    assert (prefixes[0][:, 1, 2:] == 0).all()
    assert not (suffixes[0][1, :, 4:] == 0).all()
    assert (suffixes[0][1, :, 6:] == 0).all()
    assert prefixes[0][0][0][0] == 0
    assert prefixes[0][0][1][0] != 0
    assert (prefixes[1] == 0).all()
    assert (suffixes[1][0][0] == 0).all()
    assert suffixes[1][0][1].tolist() == [39, 40, 0, 0]
    assert suffixes[1][0][3].tolist() == [0, 114, 0, 80]


def test_get_affixes_4_byte_normal_char(en_tokenizer):
    doc = en_tokenizer("and𐌞")
    suffixes = doc.get_affixes(True, 2, 6, "a", 1, 2)
    assert (suffixes[0][:, 0, 2] == 216).all()
    assert suffixes[0][3, 0, 9] == 97
    assert suffixes[1][0, 0, 1] == 97


def test_get_affixes_4_byte_special_char(en_tokenizer):
    doc = en_tokenizer("and𐌞")
    with pytest.raises(ValueError):
        doc.get_affixes(True, 2, 6, "𐌞", 2, 3)
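
The literal byte values in the asserts above follow from the encoding the affix arrays use internally. A minimal sketch, assuming (as an inference from the asserts, not documented API) that the affix and special-character arrays hold zero-padded UTF-16BE bytes of the matched characters:

assert list("✨".encode("utf-16BE")) == [39, 40]         # cf. suffixes[1][0][1]
assert list("rP".encode("utf-16BE")) == [0, 114, 0, 80]  # cf. suffixes[1][0][3]
assert "𐌞".encode("utf-16BE")[0] == 216                  # high-surrogate byte 0xD8
assert "a".encode("utf-16BE") == b"\x00a"                # 97 == ord("a")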

@@ -1743,6 +1743,8 @@ cdef class Doc:
        special_chars_enc = special_chars.encode('utf-16BE')
        cdef int sc_test_len = len(special_chars)
        if sc_test_len * 2 != len(special_chars_enc):
            raise ValueError(Errors.E1044)
        cdef np.ndarray[np.uint8_t, ndim=3] outputs = numpy.zeros(
            (len_end - len_start, num_tokens, (len_end - 1) * 2), dtype="uint8")
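
This length check rejects special characters outside the Basic Multilingual Plane: UTF-16 represents such characters as a surrogate pair, so their UTF-16BE encoding takes four bytes rather than two, which is exactly what test_get_affixes_4_byte_special_char exercises. A standalone illustration of the check:

# 𐌞 lies outside the BMP, so UTF-16 encodes it as a surrogate pair.
special_chars = "𐌞"
special_chars_enc = special_chars.encode("utf-16BE")
assert len(special_chars) == 1      # one code point
assert len(special_chars_enc) == 4  # four bytes (pair starting 0xD8), not two
# sc_test_len * 2 != len(special_chars_enc), so ValueError(Errors.E1044) is raised.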