From 910a6bc98fb3797b17d8626118061140e6ead839 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Tue, 13 Sep 2022 12:45:40 +0200
Subject: [PATCH] Add tests for get_affixes

---
NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch. In the doc.pyx hunk the header
(-1743,6 +1743,8) requires two blank context lines whose exact positions
could not be recovered; they are placed before `special_chars_enc` and
before `cdef np.ndarray` as a best guess - confirm against doc.pyx before
applying.

 spacy/tests/doc/test_doc_api.py | 37 +++++++++++++++++++++++++++++++++
 spacy/tokens/doc.pyx            |  2 ++
 2 files changed, 39 insertions(+)

diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index a64ab2ba8..c9cc4113b 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -975,3 +975,40 @@ def test_doc_spans_setdefault(en_tokenizer):
     assert len(doc.spans["key2"]) == 1
     doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
     assert len(doc.spans["key3"]) == 2
+
+
+def test_get_affixes_good_case(en_tokenizer):
+    doc = en_tokenizer("spaCy✨ and Prodigy")
+    prefixes = doc.get_affixes(False, 1, 5, "", 2, 3)
+    suffixes = doc.get_affixes(True, 2, 6, "xx✨rP", 2, 3)
+    assert prefixes[0][3, 3, 6] == suffixes[0][3, 3, 6]
+    assert prefixes[0][3, 3, 7] == suffixes[0][3, 3, 7]
+    assert prefixes[0][3, 3, 4] == suffixes[0][3, 3, 8]
+    assert prefixes[0][3, 3, 5] == suffixes[0][3, 3, 9]
+    assert (prefixes[0][0, :, 2:] == 0).all()
+    assert not (suffixes[0][0, :, 2:] == 0).all()
+    assert (suffixes[0][0, :, 4:] == 0).all()
+    assert (prefixes[0][1, :, 4:] == 0).all()
+    assert (prefixes[0][:, 1, 2:] == 0).all()
+    assert not (suffixes[0][1, :, 4:] == 0).all()
+    assert (suffixes[0][1, :, 6:] == 0).all()
+    assert prefixes[0][0][0][0] == 0
+    assert prefixes[0][0][1][0] != 0
+    assert (prefixes[1] == 0).all()
+    assert (suffixes[1][0][0] == 0).all()
+    assert suffixes[1][0][1].tolist() == [39, 40, 0, 0]
+    assert suffixes[1][0][3].tolist() == [0, 114, 0, 80]
+
+
+def test_get_affixes_4_byte_normal_char(en_tokenizer):
+    doc = en_tokenizer("and𐌞")
+    suffixes = doc.get_affixes(True, 2, 6, "a", 1, 2)
+    assert (suffixes[0][:, 0, 2] == 216).all()
+    assert suffixes[0][3, 0, 9] == 97
+    assert suffixes[1][0, 0, 1] == 97
+
+
+def test_get_affixes_4_byte_special_char(en_tokenizer):
+    doc = en_tokenizer("and𐌞")
+    with pytest.raises(ValueError):
+        doc.get_affixes(True, 2, 6, "𐌞", 2, 3)
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index dda282bf4..35082793d 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1743,6 +1743,8 @@ cdef class Doc:
 
         special_chars_enc = special_chars.encode('utf-16BE')
         cdef int sc_test_len = len(special_chars)
+        if sc_test_len * 2 != len(special_chars_enc):
+            raise ValueError(Errors.E1044)
 
         cdef np.ndarray[np.uint8_t, ndim=3] outputs = numpy.zeros(
             (len_end - len_start, num_tokens, (len_end - 1) * 2), dtype="uint8")