diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index fa895919a..c5f138a01 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -984,34 +984,28 @@ def test_get_affixes_good_case(en_tokenizer, case_sensitive):
     doc = en_tokenizer("spaCy✨ and Prodigy")
     prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3)
     suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rP", 2, 3)
-    assert prefixes[0][3, 3, 3] == suffixes[0][3, 3, 3]
-    assert prefixes[0][3, 3, 2] == suffixes[0][3, 3, 4]
-    assert (prefixes[0][0, :, 1:] == 0).all()
-    assert not (suffixes[0][0, :, 1:] == 0).all()
-    assert (suffixes[0][0, :, 2:] == 0).all()
-    assert (prefixes[0][1, :, 2:] == 0).all()
-    assert (prefixes[0][:, 1, 1:] == 0).all()
-    assert not (suffixes[0][1, :, 2:] == 0).all()
-    assert (suffixes[0][1, :, 3:] == 0).all()
-    assert suffixes[1][0][1].tolist() == [10024, 0]
+    assert prefixes[3][3, 3] == suffixes[3][3, 3]
+    assert prefixes[3][3, 2] == suffixes[3][3, 4]
+    assert suffixes[4][1].tolist() == [10024, 0]
     if case_sensitive:
-        assert suffixes[1][0][3].tolist() == [114, 80]
+        assert suffixes[4][3].tolist() == [114, 80]
     else:
-        assert suffixes[1][0][3].tolist() == [114, 112]
+        assert suffixes[4][3].tolist() == [114, 112]
     suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3)
     if case_sensitive:
-        assert suffixes[1][0][3].tolist() == [114, 0]
+        assert suffixes[4][3].tolist() == [114, 0]
     else:
-        assert suffixes[1][0][3].tolist() == [114, 112]
+        assert suffixes[4][3].tolist() == [114, 112]
 
 
 def test_get_affixes_4_byte_normal_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
     suffixes = doc.get_affixes(True, False, 2, 6, "a", 1, 2)
-    assert (suffixes[0][:, 0, 1] == 55296).all()
-    assert suffixes[0][3, 0, 4] == 97
-    assert suffixes[1][0, 0, 0] == 97
+    for i in range(0, 4):
+        assert suffixes[i][0, 1] == 55296
+    assert suffixes[3][0, 4] == 97
+    assert suffixes[4][0, 0] == 97
 
 
 def test_get_affixes_4_byte_special_char(en_tokenizer):
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 2aea6d9b4..3cd8281b8 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1759,16 +1759,20 @@ cdef class Doc:
         cdef unsigned int sc_len = len(special_chars)
         cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True)
         cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")
-
-        cdef unsigned int num_tokens = len(self)
-        cdef np.ndarray[np.uint16_t, ndim=3] outputs = numpy.zeros(
-            (len_end - len_start, num_tokens, len_end - 1), dtype="uint16")
-        cdef np.ndarray[np.uint16_t, ndim=3] sc_outputs = numpy.zeros(
-            (sc_len_end - sc_len_start, num_tokens, sc_len_end - 1), dtype="uint16")
+        cdef np.ndarray[np.uint16_t, ndim=2] output
+        cdef unsigned int num_tokens = len(self), num_normal_affixes = len_end - len_start, working_len
+
+        outputs = []
+        for working_len in range(num_normal_affixes):
+            output = numpy.zeros((num_tokens, len_start + working_len), dtype="uint16")
+            outputs.append(output)
+        for working_len in range(sc_len_end - sc_len_start):
+            output = numpy.zeros((num_tokens, sc_len_start + working_len), dtype="uint16")
+            outputs.append(output)
 
         cdef const unsigned char[:] token_bytes
         cdef np.uint16_t working_char
-        cdef unsigned int token_bytes_len, token_idx, char_idx, working_len, sc_char_idx, sc_test_idx, working_sc_len
+        cdef unsigned int token_bytes_len, token_idx, char_idx, sc_char_idx, sc_test_idx, working_sc_len
         cdef unsigned int char_byte_idx
 
         for token_idx in range(num_tokens):
@@ -1786,19 +1790,19 @@
                 for working_len in range(len_end-1, len_start-1, -1):
                     if char_idx >= working_len:
                         break
-                    outputs[working_len - len_start, token_idx, char_idx] = working_char
+                    outputs[working_len - len_start][token_idx, char_idx] = working_char
                 sc_test_idx = 0
                 while sc_len > sc_test_idx:
                     if working_char == scs[sc_test_idx]:
                         for working_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
                             if sc_char_idx >= working_sc_len:
                                 break
-                            sc_outputs[working_sc_len - sc_len_start, token_idx, sc_char_idx] = working_char
+                            outputs[num_normal_affixes + working_sc_len - sc_len_start][token_idx, sc_char_idx] = working_char
                         sc_char_idx += 1
                         break
                     sc_test_idx += 1
                 char_idx += 1
 
-        return outputs, sc_outputs
+        return outputs
 
     @staticmethod
     def _get_array_attrs():
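
For review context, a minimal usage sketch of the new return convention, assuming the branch in this diff (Doc.get_affixes is not part of released spaCy, and the argument order below is inferred from the tests above):

    import spacy

    # Hypothetical setup: any pipeline whose Doc exposes get_affixes on this branch.
    nlp = spacy.blank("en")
    doc = nlp("spaCy and Prodigy")  # 3 tokens

    # Inferred signature: get_affixes(suffixes_not_prefixes, case_sensitive,
    #                                 len_start, len_end, special_chars,
    #                                 sc_len_start, sc_len_end)
    arrays = doc.get_affixes(True, False, 2, 6, "xr", 2, 3)

    # Instead of an (outputs, sc_outputs) tuple of 3-D arrays, the method now
    # returns one flat list of 2-D uint16 arrays of UTF-16 code units:
    # len_end - len_start normal-affix arrays of shape (num_tokens, len_start + i),
    # followed by sc_len_end - sc_len_start special-character arrays of shape
    # (num_tokens, sc_len_start + i), zero-padded on the right.
    for arr in arrays:
        print(arr.shape)  # (3, 2), (3, 3), (3, 4), (3, 5), then (3, 2)

Allocating one array per affix length means each array is only as wide as that length, rather than every slice of a 3-D block being padded out to len_end - 1 columns as before; the tests accordingly index the flat list (prefixes[3], suffixes[4]) instead of the old tuple-of-3-D layout.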