From 749da9d34890f0d0f7c3594ea0b95169f0339f16 Mon Sep 17 00:00:00 2001 From: richardpaulhudson Date: Fri, 28 Oct 2022 14:42:42 +0200 Subject: [PATCH] Speed improvements --- spacy/tokens/doc.pxd | 4 ++-- spacy/tokens/doc.pyx | 55 +++++++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd index ec9e12731..796525663 100644 --- a/spacy/tokens/doc.pxd +++ b/spacy/tokens/doc.pxd @@ -43,7 +43,7 @@ cdef void _set_affix_lengths( unsigned char* aff_l_buf, const int pref_l, const int suff_l, -) +) nogil cdef void _search_for_chars( @@ -56,7 +56,7 @@ cdef void _search_for_chars( int max_res_l, unsigned char* l_buf, bint suffs_not_prefs -) +) nogil cdef class Doc: diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 407323236..668353024 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1735,22 +1735,22 @@ cdef class Doc: j += 1 return output - + @cython.boundscheck(False) # Deactivate bounds checking def get_character_combination_hashes(self, *, const bint cs, - np.ndarray p_lengths, - np.ndarray s_lengths, + int[:] p_lengths, + int[:] s_lengths, const unsigned char[:] ps_1byte_ch, const unsigned char[:] ps_2byte_ch, const unsigned char[:] ps_3byte_ch, const unsigned char[:] ps_4byte_ch, - np.ndarray ps_lengths, + int[:] ps_lengths, const unsigned char[:] ss_1byte_ch, const unsigned char[:] ss_2byte_ch, const unsigned char[:] ss_3byte_ch, const unsigned char[:] ss_4byte_ch, - np.ndarray ss_lengths, + int[:] ss_lengths, ): """ Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations @@ -1782,10 +1782,10 @@ cdef class Doc: # Define the result array and work out what is used for what in axis 1 cdef int num_toks = len(self) - cdef int p_h_num = p_lengths.shape[0] - cdef int s_h_num = s_lengths.shape[0], s_h_end = p_h_num + s_h_num - cdef int ps_h_num = ps_lengths.shape[0], ps_h_end = s_h_end + ps_h_num - cdef int ss_h_num = ss_lengths.shape[0], ss_h_end = ps_h_end + ss_h_num + cdef int p_h_num = len(p_lengths) + cdef int s_h_num = len(s_lengths), s_h_end = p_h_num + s_h_num + cdef int ps_h_num = len(ps_lengths), ps_h_end = s_h_end + ps_h_num + cdef int ss_h_num = len(ss_lengths), ss_h_end = ps_h_end + ss_h_num cdef np.ndarray[np.int64_t, ndim=2] hashes hashes = numpy.empty((num_toks, ss_h_end), dtype="int64") @@ -1796,12 +1796,13 @@ cdef class Doc: cdef int ss_max_l = ss_lengths[-1] if ss_h_num > 0 else 0 # Define / allocate buffers + cdef Pool mem = Pool() cdef int aff_l = p_max_l + s_max_l - cdef unsigned char* aff_l_buf = self.mem.alloc(aff_l, 1) - cdef unsigned char* ps_res_buf = self.mem.alloc(ps_max_l, 4) - cdef unsigned char* ps_l_buf = self.mem.alloc(ps_max_l, 1) - cdef unsigned char* ss_res_buf = self.mem.alloc(ss_max_l, 4) - cdef unsigned char* ss_l_buf = self.mem.alloc(ss_max_l, 1) + cdef unsigned char* aff_l_buf = mem.alloc(aff_l, 1) + cdef unsigned char* ps_res_buf = mem.alloc(ps_max_l, 4) + cdef unsigned char* ps_l_buf = mem.alloc(ps_max_l, 1) + cdef unsigned char* ss_res_buf = mem.alloc(ss_max_l, 4) + cdef unsigned char* ss_l_buf = mem.alloc(ss_max_l, 1) # Define memory views on length arrays cdef int[:] p_lengths_v = p_lengths @@ -1854,12 +1855,6 @@ cdef class Doc: hash_val = hash32(ss_res_buf, offset, 0) hashes[tok_i, hash_idx] = hash_val - - self.mem.free(aff_l_buf) - self.mem.free(ps_res_buf) - self.mem.free(ps_l_buf) - self.mem.free(ss_res_buf) - self.mem.free(ss_l_buf) return hashes @staticmethod @@ -2042,13 +2037,13 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end): lca_matrix[k, j] = lca - start return lca_matrix - +@cython.boundscheck(False) # Deactivate bounds checking cdef void _set_affix_lengths( const unsigned char[:] tok_str, unsigned char* aff_l_buf, const int pref_l, const int suff_l, -): +) nogil: """ Populate *aff_l_buf*, which has length *pref_l+suff_l* with the byte lengths of the first *pref_l* and the last *suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are populated with the byte length of the whole word. @@ -2085,6 +2080,7 @@ cdef void _set_affix_lengths( if aff_l_buf_idx < pref_l + suff_l: memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l + suff_l - aff_l_buf_idx) +@cython.boundscheck(False) # Deactivate bounds checking cdef void _search_for_chars( const unsigned char[:] tok_str, const unsigned char[:] s_1byte_ch, @@ -2095,7 +2091,7 @@ cdef void _search_for_chars( int max_res_l, unsigned char* l_buf, bint suffs_not_prefs -): +) nogil: """ Search *tok_str* within a string for characters within the *s_byte_ch> buffers, starting at the beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches, it is added to *res_buf* and the byte length up to that point is added to *len_buf*. When nothing @@ -2110,7 +2106,8 @@ cdef void _search_for_chars( The calling code ensures that lengths greater than 255 cannot occur. suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning. """ - cdef int tok_str_l = len(tok_str), search_char_idx = 0, res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx + cdef int tok_str_l = len(tok_str), res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx, search_char_idx + cdef int search_chars_l cdef const unsigned char[:] search_chars cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0 @@ -2121,7 +2118,10 @@ cdef void _search_for_chars( this_tok_str_idx == tok_str_l or (tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0]. ): - ch_wdth = abs(this_tok_str_idx - last_tok_str_idx) + if this_tok_str_idx > last_tok_str_idx: + ch_wdth = this_tok_str_idx - last_tok_str_idx + else: + ch_wdth = last_tok_str_idx - this_tok_str_idx if ch_wdth == 1: search_chars = s_1byte_ch elif ch_wdth == 2: @@ -2130,9 +2130,11 @@ cdef void _search_for_chars( search_chars = s_3byte_ch else: search_chars = s_4byte_ch + search_chars_l = len(search_chars) tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx - for search_char_idx in range(0, len(search_chars), ch_wdth): + search_char_idx = 0 + while search_char_idx < search_chars_l: cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth) if cmp_result == 0: memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth) @@ -2143,6 +2145,7 @@ cdef void _search_for_chars( return if cmp_result <= 0: break + search_char_idx += ch_wdth last_tok_str_idx = this_tok_str_idx if suffs_not_prefs: this_tok_str_idx -= 1