From b2074a15d10fb15dcc5141a3398d99fccabe51d2 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Mon, 12 Sep 2022 22:20:57 +0200
Subject: [PATCH] Improvements

---
 spacy/tokens/doc.pxd |  1 -
 spacy/tokens/doc.pyx | 75 ++++++++++++++++++++++++--------------------
 2 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index ae5e945f5..eede5160d 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -9,7 +9,6 @@ from ..attrs cimport attr_id_t
 
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
-cdef np.ndarray init_array(int num_tokens, int length)
 
 
 ctypedef const LexemeC* const_Lexeme_ptr
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 987690dc3..dda282bf4 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -94,12 +94,6 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
         return get_token_attr(token, feat_name)
 
 
-cdef np.ndarray init_array(int num_tokens, int length):
-    cdef np.ndarray output = numpy.zeros((num_tokens, length), dtype='uint8')
-    output.fill(255)
-    return output
-
-
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
     missing = "missing"
@@ -1740,41 +1734,54 @@ cdef class Doc:
                     j += 1
         return output
 
-    def get_suffixes(self, int min_length, int max_length, special_chars:str, int sc_min_length, int sc_max_length):
+    def get_affixes(self, bint suffs_not_prefs, int len_start, int len_end, special_chars:str, int sc_len_start, int sc_len_end):
         """
         TODO
         """
-        byte_strings = [token.orth_.encode('utf8') for token in self]
-        special_chars_enc = special_chars.encode('utf8')
-        cdef num_tokens = len(byte_strings)
-        outputs = []
-        for length in range(min_length, max_length+1):
-            outputs.append(init_array(num_tokens, length))
-        for length in range(sc_min_length, sc_max_length+1):
-            outputs.append(init_array(num_tokens, length))
-
-        cdef int token_i, sc_char_i, idx
+        byte_strings = [token.orth_.encode('utf-16BE') for token in self]
+        cdef int num_tokens = len(byte_strings)
+
+        special_chars_enc = special_chars.encode('utf-16BE')
+        cdef int sc_test_len = len(special_chars)
+
+        cdef np.ndarray[np.uint8_t, ndim=3] outputs = numpy.zeros(
+            (len_end - len_start, num_tokens, (len_end - 1) * 2), dtype="uint8")
+        cdef np.ndarray[np.uint8_t, ndim=3] sc_outputs = numpy.zeros(
+            (sc_len_end - sc_len_start, num_tokens, (sc_len_end - 1) * 2), dtype="uint8")
+
         cdef bytes byte_string
-        cdef unsigned char utf8_char
-        cdef num_normal_arr = 1 + max_length - min_length
-        cdef num_sc_arr = 1 + sc_max_length - sc_min_length
-        for token_i, byte_string in enumerate(byte_strings):
-            sc_char_i = 0
+        cdef char this_char_part, next_char_part, this_test_char_part, next_test_char_part
+        cdef int len_byte_string, idx, sc_char_idx, sc_test_idx, this_len, this_sc_len
+
+        for token_idx, byte_string in enumerate(byte_strings):
             idx = 0
-            while (idx < max_length or sc_char_i < sc_max_length) and idx < len(byte_string):
-                this_char = byte_string[len(byte_string) - (1 + idx)]
-                for normal_arr_i in range(num_normal_arr - 1, -1, -1):
-                    if idx >= normal_arr_i + min_length:
+            sc_char_idx = 0
+            len_byte_string = len(byte_string)
+
+            while (idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and idx * 2 < len_byte_string:
+                char_first_byte_idx = len_byte_string - 2 * (idx + 1) if suffs_not_prefs else idx * 2
+                this_char_part = byte_string[char_first_byte_idx]
+                next_char_part = byte_string[char_first_byte_idx + 1]
+                for this_len in range(len_end-1, len_start-1, -1):
+                    if idx >= this_len:
                         break
-                    outputs[normal_arr_i][token_i, idx] = this_char
-                if this_char in special_chars_enc:
-                    for sc_arr_i in range(num_sc_arr - 1, -1, -1):
-                        if sc_char_i >= sc_arr_i + sc_min_length:
-                            break
-                        outputs[sc_arr_i + num_normal_arr][token_i, sc_char_i] = this_char
-                    sc_char_i += 1
+                    outputs[this_len - len_start, token_idx, idx * 2] = this_char_part
+                    outputs[this_len - len_start, token_idx, idx * 2 + 1] = next_char_part
+                sc_test_idx = 0
+                while sc_test_len > sc_test_idx:
+                    this_test_char_part = special_chars_enc[sc_test_idx*2]
+                    next_test_char_part = special_chars_enc[sc_test_idx*2 + 1]
+                    if this_char_part == this_test_char_part and next_char_part == next_test_char_part:
+                        for this_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
+                            if sc_char_idx >= this_sc_len:
+                                break
+                            sc_outputs[this_sc_len - sc_len_start, token_idx, sc_char_idx * 2] = this_char_part
+                            sc_outputs[this_sc_len - sc_len_start, token_idx, sc_char_idx * 2 + 1] = next_char_part
+                        sc_char_idx += 1
+                        break
+                    sc_test_idx += 1
                 idx += 1
-        return outputs
+        return outputs, sc_outputs
 
     @staticmethod
     def _get_array_attrs():
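
A minimal usage sketch of the new Doc.get_affixes method, assuming the patch above has been applied and the extension rebuilt; the argument values and example text are illustrative only, not taken from the patch.

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("This is an affix extraction example")

    # suffs_not_prefs=True selects suffixes rather than prefixes.
    # Plain affix lengths run from len_start=2 up to len_end - 1 = 4 characters;
    # special-character affixes, built only from characters found in "xe",
    # run from sc_len_start=1 up to sc_len_end - 1 = 2 characters.
    affixes, sc_affixes = doc.get_affixes(True, 2, 5, "xe", 1, 3)

    # affixes has shape (len_end - len_start, len(doc), (len_end - 1) * 2):
    # one plane per affix length, each row holding the UTF-16BE bytes of that
    # token's affix characters, zero-padded where the token is shorter.
    print(affixes.shape, sc_affixes.shape)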