From 146d286da6affae688eaa40757c513d291294a1c Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Wed, 14 Sep 2022 14:47:00 +0200
Subject: [PATCH] Refactoring

---
 spacy/tests/doc/test_doc_api.py | 18 +++++---
 spacy/tokens/doc.pxd            |  1 +
 spacy/tokens/doc.pyx            | 79 ++++++++++++++++++++-------------
 3 files changed, 61 insertions(+), 37 deletions(-)

diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index 1b0fb7ad9..9630c895e 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -977,10 +977,13 @@ def test_doc_spans_setdefault(en_tokenizer):
     assert len(doc.spans["key3"]) == 2
 
 
-def test_get_affixes_good_case(en_tokenizer):
+@pytest.mark.parametrize(
+    "case_sensitive", [True, False]
+)
+def test_get_affixes_good_case(en_tokenizer, case_sensitive):
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    prefixes = doc.get_affixes(False, 1, 5, "", 2, 3)
-    suffixes = doc.get_affixes(True, 2, 6, "xx✨rP", 2, 3)
+    prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3)
+    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3)
     assert prefixes[0][3, 3, 3] == suffixes[0][3, 3, 3]
     assert prefixes[0][3, 3, 2] == suffixes[0][3, 3, 4]
     assert (prefixes[0][0, :, 1:] == 0).all()
@@ -991,12 +994,15 @@
     assert not (suffixes[0][1, :, 2:] == 0).all()
     assert (suffixes[0][1, :, 3:] == 0).all()
     assert suffixes[1][0][1].tolist() == [10024, 0]
-    assert suffixes[1][0][3].tolist() == [114, 112]
+    if case_sensitive:
+        assert suffixes[1][0][3].tolist() == [114, 0]
+    else:
+        assert suffixes[1][0][3].tolist() == [114, 112]
 
 
 def test_get_affixes_4_byte_normal_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
-    suffixes = doc.get_affixes(True, 2, 6, "a", 1, 2)
+    suffixes = doc.get_affixes(True, False, 2, 6, "a", 1, 2)
     assert (suffixes[0][:, 0, 1] == 55296).all()
     assert suffixes[0][3, 0, 4] == 97
     assert suffixes[1][0, 0, 0] == 97
@@ -1005,4 +1011,4 @@
 def test_get_affixes_4_byte_special_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
     with pytest.raises(ValueError):
-        doc.get_affixes(True, 2, 6, "𐌞", 2, 3)
+        doc.get_affixes(True, False, 2, 6, "𐌞", 2, 3)
diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index eede5160d..56176a3a7 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -9,6 +9,7 @@ from ..attrs cimport attr_id_t
 
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes)
 
 
 ctypedef const LexemeC* const_Lexeme_ptr
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index c0b48f779..c732cd9ee 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -94,6 +94,18 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
         return get_token_attr(token, feat_name)
 
 
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes):
+    """
+    Returns a memory view of the UTF-16 representation of a string with the platform's default endianness.
+    Raises a ValueError if *check_2_bytes == True* and one or more characters in the UTF-16 representation
+    occupy four bytes rather than two.
+    """
+    cdef const unsigned char[:] view = memoryview(unicode_string.encode("UTF-16"))[2:]  # skip the 2-byte byte order mark
+    if check_2_bytes and len(unicode_string) * 2 != len(view):
+        raise ValueError(Errors.E1044)
+    return view
+
+
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
     missing = "missing"
@@ -1734,52 +1746,57 @@ cdef class Doc:
                     j += 1
         return output
 
-    def get_affixes(self, bint suffs_not_prefs, int len_start, int len_end, special_chars:str, int sc_len_start, int sc_len_end):
+    def get_affixes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
+            str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
         """
         TODO
         """
-        cdef bytes byte_string
-        cdef np.uint16_t this_char
-        cdef int idx, len_byte_string, sc_char_idx, sc_test_idx, this_len, this_sc_len
-
-        cdef int num_tokens = len(self)
-        cdef bytes sc_enc = special_chars.lower().encode("utf-16BE")
-        cdef int sc_test_len = len(special_chars)
-        if sc_test_len * 2 != len(sc_enc):
-            raise ValueError(Errors.E1044)
-        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.empty((sc_test_len,), dtype="uint16")
-        for idx in range(sc_test_len):
-            scs[idx] = (sc_enc[idx*2] << 8) + sc_enc[idx * 2 + 1]
-
+        if case_sensitive:
+            token_attrs = [t.orth_ for t in self]
+        else:
+            token_attrs = [t.lower_ for t in self]
+        cdef unsigned int sc_len = len(special_chars)
+        cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True)
+        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")
+
+        cdef unsigned int num_tokens = len(self)
         cdef np.ndarray[np.uint16_t, ndim=3] outputs = numpy.zeros(
             (len_end - len_start, num_tokens, len_end - 1), dtype="uint16")
         cdef np.ndarray[np.uint16_t, ndim=3] sc_outputs = numpy.zeros(
             (sc_len_end - sc_len_start, num_tokens, sc_len_end - 1), dtype="uint16")
 
-        for token_idx in range(num_tokens):
-            byte_string = self[token_idx].lower_.encode("utf-16BE")
-            idx = 0
-            sc_char_idx = 0
-            len_byte_string = len(byte_string)
+        cdef const unsigned char[:] token_bytes
+        cdef np.uint16_t working_char
+        cdef unsigned int token_bytes_len, token_idx, char_idx, working_len, sc_char_idx, sc_test_idx, working_sc_len
+        cdef unsigned int char_byte_idx
 
-            while (idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and idx * 2 < len_byte_string:
-                char_first_byte_idx = len_byte_string - 2 * (idx + 1) if suffs_not_prefs else idx * 2
-                this_char = (byte_string[char_first_byte_idx] << 8) + byte_string[char_first_byte_idx + 1]
-                for this_len in range(len_end-1, len_start-1, -1):
-                    if idx >= this_len:
+        for token_idx in range(num_tokens):
+            token_bytes = get_utf16_memoryview(token_attrs[token_idx], False)
+            char_idx = 0
+            sc_char_idx = 0
+            token_bytes_len = len(token_bytes)
+
+            while (char_idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and char_idx * 2 < token_bytes_len:
+                if suffs_not_prefs:
+                    char_byte_idx = token_bytes_len - 2 * (char_idx + 1)
+                else:
+                    char_byte_idx = char_idx * 2
+                working_char = (<np.uint16_t *> &token_bytes[char_byte_idx])[0]
+                for working_len in range(len_end-1, len_start-1, -1):
+                    if char_idx >= working_len:
                         break
-                    outputs[this_len - len_start, token_idx, idx] = this_char
+                    outputs[working_len - len_start, token_idx, char_idx] = working_char
                 sc_test_idx = 0
-                while sc_test_len > sc_test_idx:
-                    if this_char == scs[sc_test_idx]:
-                        for this_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
-                            if sc_char_idx >= this_sc_len:
+                while sc_len > sc_test_idx:
+                    if working_char == scs[sc_test_idx]:
+                        for working_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
+                            if sc_char_idx >= working_sc_len:
                                 break
-                            sc_outputs[this_sc_len - sc_len_start, token_idx, sc_char_idx] = this_char
+                            sc_outputs[working_sc_len - sc_len_start, token_idx, sc_char_idx] = working_char
                         sc_char_idx += 1
                         break
                     sc_test_idx += 1
-                idx += 1
+                char_idx += 1
         return outputs, sc_outputs
 
     @staticmethod
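
A minimal usage sketch of get_affixes after this refactoring, based on the calls made in the tests above; the blank English pipeline and the sample text are assumptions for illustration only, and the exact array contents depend on the tokens.

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("spaCy and Prodigy")

    # Signature after this patch:
    # get_affixes(suffs_not_prefs, case_sensitive, len_start, len_end,
    #             special_chars, sc_len_start, sc_len_end)
    # Returns a pair of uint16 arrays of UTF-16 code units:
    # (affix characters, special-character affix characters).
    prefixes, sc_prefixes = doc.get_affixes(False, False, 1, 5, "", 2, 3)
    suffixes, sc_suffixes = doc.get_affixes(True, True, 2, 6, "xp", 2, 3)

    # prefixes has shape (len_end - len_start, n_tokens, len_end - 1)
    print(prefixes.shape, sc_suffixes.shape)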