Speed improvements

This commit is contained in:
richardpaulhudson 2022-10-28 14:42:42 +02:00
parent 217ff36559
commit 749da9d348
2 changed files with 31 additions and 28 deletions

View File

@ -43,7 +43,7 @@ cdef void _set_affix_lengths(
unsigned char* aff_l_buf, unsigned char* aff_l_buf,
const int pref_l, const int pref_l,
const int suff_l, const int suff_l,
) ) nogil
cdef void _search_for_chars( cdef void _search_for_chars(
@ -56,7 +56,7 @@ cdef void _search_for_chars(
int max_res_l, int max_res_l,
unsigned char* l_buf, unsigned char* l_buf,
bint suffs_not_prefs bint suffs_not_prefs
) ) nogil
cdef class Doc: cdef class Doc:

View File

@ -1735,22 +1735,22 @@ cdef class Doc:
j += 1 j += 1
return output return output
@cython.boundscheck(False) # Deactivate bounds checking
def get_character_combination_hashes(self, def get_character_combination_hashes(self,
*, *,
const bint cs, const bint cs,
np.ndarray p_lengths, int[:] p_lengths,
np.ndarray s_lengths, int[:] s_lengths,
const unsigned char[:] ps_1byte_ch, const unsigned char[:] ps_1byte_ch,
const unsigned char[:] ps_2byte_ch, const unsigned char[:] ps_2byte_ch,
const unsigned char[:] ps_3byte_ch, const unsigned char[:] ps_3byte_ch,
const unsigned char[:] ps_4byte_ch, const unsigned char[:] ps_4byte_ch,
np.ndarray ps_lengths, int[:] ps_lengths,
const unsigned char[:] ss_1byte_ch, const unsigned char[:] ss_1byte_ch,
const unsigned char[:] ss_2byte_ch, const unsigned char[:] ss_2byte_ch,
const unsigned char[:] ss_3byte_ch, const unsigned char[:] ss_3byte_ch,
const unsigned char[:] ss_4byte_ch, const unsigned char[:] ss_4byte_ch,
np.ndarray ss_lengths, int[:] ss_lengths,
): ):
""" """
Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
@ -1782,10 +1782,10 @@ cdef class Doc:
# Define the result array and work out what is used for what in axis 1 # Define the result array and work out what is used for what in axis 1
cdef int num_toks = len(self) cdef int num_toks = len(self)
cdef int p_h_num = p_lengths.shape[0] cdef int p_h_num = len(p_lengths)
cdef int s_h_num = s_lengths.shape[0], s_h_end = p_h_num + s_h_num cdef int s_h_num = len(s_lengths), s_h_end = p_h_num + s_h_num
cdef int ps_h_num = ps_lengths.shape[0], ps_h_end = s_h_end + ps_h_num cdef int ps_h_num = len(ps_lengths), ps_h_end = s_h_end + ps_h_num
cdef int ss_h_num = ss_lengths.shape[0], ss_h_end = ps_h_end + ss_h_num cdef int ss_h_num = len(ss_lengths), ss_h_end = ps_h_end + ss_h_num
cdef np.ndarray[np.int64_t, ndim=2] hashes cdef np.ndarray[np.int64_t, ndim=2] hashes
hashes = numpy.empty((num_toks, ss_h_end), dtype="int64") hashes = numpy.empty((num_toks, ss_h_end), dtype="int64")
@ -1796,12 +1796,13 @@ cdef class Doc:
cdef int ss_max_l = ss_lengths[-1] if ss_h_num > 0 else 0 cdef int ss_max_l = ss_lengths[-1] if ss_h_num > 0 else 0
# Define / allocate buffers # Define / allocate buffers
cdef Pool mem = Pool()
cdef int aff_l = p_max_l + s_max_l cdef int aff_l = p_max_l + s_max_l
cdef unsigned char* aff_l_buf = <unsigned char*> self.mem.alloc(aff_l, 1) cdef unsigned char* aff_l_buf = <unsigned char*> mem.alloc(aff_l, 1)
cdef unsigned char* ps_res_buf = <unsigned char*> self.mem.alloc(ps_max_l, 4) cdef unsigned char* ps_res_buf = <unsigned char*> mem.alloc(ps_max_l, 4)
cdef unsigned char* ps_l_buf = <unsigned char*> self.mem.alloc(ps_max_l, 1) cdef unsigned char* ps_l_buf = <unsigned char*> mem.alloc(ps_max_l, 1)
cdef unsigned char* ss_res_buf = <unsigned char*> self.mem.alloc(ss_max_l, 4) cdef unsigned char* ss_res_buf = <unsigned char*> mem.alloc(ss_max_l, 4)
cdef unsigned char* ss_l_buf = <unsigned char*> self.mem.alloc(ss_max_l, 1) cdef unsigned char* ss_l_buf = <unsigned char*> mem.alloc(ss_max_l, 1)
# Define memory views on length arrays # Define memory views on length arrays
cdef int[:] p_lengths_v = p_lengths cdef int[:] p_lengths_v = p_lengths
@ -1854,12 +1855,6 @@ cdef class Doc:
hash_val = hash32(ss_res_buf, offset, 0) hash_val = hash32(ss_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val hashes[tok_i, hash_idx] = hash_val
self.mem.free(aff_l_buf)
self.mem.free(ps_res_buf)
self.mem.free(ps_l_buf)
self.mem.free(ss_res_buf)
self.mem.free(ss_l_buf)
return hashes return hashes
@staticmethod @staticmethod
@ -2042,13 +2037,13 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
lca_matrix[k, j] = lca - start lca_matrix[k, j] = lca - start
return lca_matrix return lca_matrix
@cython.boundscheck(False) # Deactivate bounds checking
cdef void _set_affix_lengths( cdef void _set_affix_lengths(
const unsigned char[:] tok_str, const unsigned char[:] tok_str,
unsigned char* aff_l_buf, unsigned char* aff_l_buf,
const int pref_l, const int pref_l,
const int suff_l, const int suff_l,
): ) nogil:
""" Populate *aff_l_buf*, which has length *pref_l+suff_l* with the byte lengths of the first *pref_l* and the last """ Populate *aff_l_buf*, which has length *pref_l+suff_l* with the byte lengths of the first *pref_l* and the last
*suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are *suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are
populated with the byte length of the whole word. populated with the byte length of the whole word.
@ -2085,6 +2080,7 @@ cdef void _set_affix_lengths(
if aff_l_buf_idx < pref_l + suff_l: if aff_l_buf_idx < pref_l + suff_l:
memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l + suff_l - aff_l_buf_idx) memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l + suff_l - aff_l_buf_idx)
@cython.boundscheck(False) # Deactivate bounds checking
cdef void _search_for_chars( cdef void _search_for_chars(
const unsigned char[:] tok_str, const unsigned char[:] tok_str,
const unsigned char[:] s_1byte_ch, const unsigned char[:] s_1byte_ch,
@ -2095,7 +2091,7 @@ cdef void _search_for_chars(
int max_res_l, int max_res_l,
unsigned char* l_buf, unsigned char* l_buf,
bint suffs_not_prefs bint suffs_not_prefs
): ) nogil:
""" Search *tok_str* for characters within the *s_<n>byte_ch* buffers, starting at the """ Search *tok_str* for characters within the *s_<n>byte_ch* buffers, starting at the
beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches, beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches,
it is added to *res_buf* and the byte length up to that point is added to *l_buf*. When nothing it is added to *res_buf* and the byte length up to that point is added to *l_buf*. When nothing
@ -2110,7 +2106,8 @@ cdef void _search_for_chars(
The calling code ensures that lengths greater than 255 cannot occur. The calling code ensures that lengths greater than 255 cannot occur.
suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning. suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning.
""" """
cdef int tok_str_l = len(tok_str), search_char_idx = 0, res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx cdef int tok_str_l = len(tok_str), res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx, search_char_idx
cdef int search_chars_l
cdef const unsigned char[:] search_chars cdef const unsigned char[:] search_chars
cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0 cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0
@ -2121,7 +2118,10 @@ cdef void _search_for_chars(
this_tok_str_idx == tok_str_l or this_tok_str_idx == tok_str_l or
(tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0]. (tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0].
): ):
ch_wdth = abs(this_tok_str_idx - last_tok_str_idx) if this_tok_str_idx > last_tok_str_idx:
ch_wdth = this_tok_str_idx - last_tok_str_idx
else:
ch_wdth = last_tok_str_idx - this_tok_str_idx
if ch_wdth == 1: if ch_wdth == 1:
search_chars = s_1byte_ch search_chars = s_1byte_ch
elif ch_wdth == 2: elif ch_wdth == 2:
@ -2130,9 +2130,11 @@ cdef void _search_for_chars(
search_chars = s_3byte_ch search_chars = s_3byte_ch
else: else:
search_chars = s_4byte_ch search_chars = s_4byte_ch
search_chars_l = len(search_chars)
tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx
for search_char_idx in range(0, len(search_chars), ch_wdth): search_char_idx = 0
while search_char_idx < search_chars_l:
cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth) cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth)
if cmp_result == 0: if cmp_result == 0:
memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth) memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth)
@ -2143,6 +2145,7 @@ cdef void _search_for_chars(
return return
if cmp_result <= 0: if cmp_result <= 0:
break break
search_char_idx += ch_wdth
last_tok_str_idx = this_tok_str_idx last_tok_str_idx = this_tok_str_idx
if suffs_not_prefs: if suffs_not_prefs:
this_tok_str_idx -= 1 this_tok_str_idx -= 1