Speed improvements

This commit is contained in:
richardpaulhudson 2022-10-28 14:42:42 +02:00
parent 217ff36559
commit 749da9d348
2 changed files with 31 additions and 28 deletions

View File

@ -43,7 +43,7 @@ cdef void _set_affix_lengths(
unsigned char* aff_l_buf,
const int pref_l,
const int suff_l,
)
) nogil
cdef void _search_for_chars(
@ -56,7 +56,7 @@ cdef void _search_for_chars(
int max_res_l,
unsigned char* l_buf,
bint suffs_not_prefs
)
) nogil
cdef class Doc:

View File

@ -1735,22 +1735,22 @@ cdef class Doc:
j += 1
return output
@cython.boundscheck(False) # Deactivate bounds checking
def get_character_combination_hashes(self,
*,
const bint cs,
np.ndarray p_lengths,
np.ndarray s_lengths,
int[:] p_lengths,
int[:] s_lengths,
const unsigned char[:] ps_1byte_ch,
const unsigned char[:] ps_2byte_ch,
const unsigned char[:] ps_3byte_ch,
const unsigned char[:] ps_4byte_ch,
np.ndarray ps_lengths,
int[:] ps_lengths,
const unsigned char[:] ss_1byte_ch,
const unsigned char[:] ss_2byte_ch,
const unsigned char[:] ss_3byte_ch,
const unsigned char[:] ss_4byte_ch,
np.ndarray ss_lengths,
int[:] ss_lengths,
):
"""
Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
@ -1782,10 +1782,10 @@ cdef class Doc:
# Define the result array and work out what is used for what in axis 1
cdef int num_toks = len(self)
cdef int p_h_num = p_lengths.shape[0]
cdef int s_h_num = s_lengths.shape[0], s_h_end = p_h_num + s_h_num
cdef int ps_h_num = ps_lengths.shape[0], ps_h_end = s_h_end + ps_h_num
cdef int ss_h_num = ss_lengths.shape[0], ss_h_end = ps_h_end + ss_h_num
cdef int p_h_num = len(p_lengths)
cdef int s_h_num = len(s_lengths), s_h_end = p_h_num + s_h_num
cdef int ps_h_num = len(ps_lengths), ps_h_end = s_h_end + ps_h_num
cdef int ss_h_num = len(ss_lengths), ss_h_end = ps_h_end + ss_h_num
cdef np.ndarray[np.int64_t, ndim=2] hashes
hashes = numpy.empty((num_toks, ss_h_end), dtype="int64")
@ -1796,12 +1796,13 @@ cdef class Doc:
cdef int ss_max_l = ss_lengths[-1] if ss_h_num > 0 else 0
# Define / allocate buffers
cdef Pool mem = Pool()
cdef int aff_l = p_max_l + s_max_l
cdef unsigned char* aff_l_buf = <unsigned char*> self.mem.alloc(aff_l, 1)
cdef unsigned char* ps_res_buf = <unsigned char*> self.mem.alloc(ps_max_l, 4)
cdef unsigned char* ps_l_buf = <unsigned char*> self.mem.alloc(ps_max_l, 1)
cdef unsigned char* ss_res_buf = <unsigned char*> self.mem.alloc(ss_max_l, 4)
cdef unsigned char* ss_l_buf = <unsigned char*> self.mem.alloc(ss_max_l, 1)
cdef unsigned char* aff_l_buf = <unsigned char*> mem.alloc(aff_l, 1)
cdef unsigned char* ps_res_buf = <unsigned char*> mem.alloc(ps_max_l, 4)
cdef unsigned char* ps_l_buf = <unsigned char*> mem.alloc(ps_max_l, 1)
cdef unsigned char* ss_res_buf = <unsigned char*> mem.alloc(ss_max_l, 4)
cdef unsigned char* ss_l_buf = <unsigned char*> mem.alloc(ss_max_l, 1)
# Define memory views on length arrays
cdef int[:] p_lengths_v = p_lengths
@ -1854,12 +1855,6 @@ cdef class Doc:
hash_val = hash32(ss_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val
self.mem.free(aff_l_buf)
self.mem.free(ps_res_buf)
self.mem.free(ps_l_buf)
self.mem.free(ss_res_buf)
self.mem.free(ss_l_buf)
return hashes
@staticmethod
@ -2042,13 +2037,13 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
lca_matrix[k, j] = lca - start
return lca_matrix
@cython.boundscheck(False) # Deactivate bounds checking
cdef void _set_affix_lengths(
const unsigned char[:] tok_str,
unsigned char* aff_l_buf,
const int pref_l,
const int suff_l,
):
) nogil:
""" Populate *aff_l_buf*, which has length *pref_l+suff_l*, with the byte lengths of the first *pref_l* and the last
*suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are
populated with the byte length of the whole word.
@ -2085,6 +2080,7 @@ cdef void _set_affix_lengths(
if aff_l_buf_idx < pref_l + suff_l:
memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l + suff_l - aff_l_buf_idx)
@cython.boundscheck(False) # Deactivate bounds checking
cdef void _search_for_chars(
const unsigned char[:] tok_str,
const unsigned char[:] s_1byte_ch,
@ -2095,7 +2091,7 @@ cdef void _search_for_chars(
int max_res_l,
unsigned char* l_buf,
bint suffs_not_prefs
):
) nogil:
""" Search the string *tok_str* for characters within the *s_<n>byte_ch* buffers, starting at the
beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches,
it is added to *res_buf* and the byte length up to that point is added to *l_buf*. When nothing
@ -2110,7 +2106,8 @@ cdef void _search_for_chars(
The calling code ensures that lengths greater than 255 cannot occur.
suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning.
"""
cdef int tok_str_l = len(tok_str), search_char_idx = 0, res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx
cdef int tok_str_l = len(tok_str), res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx, search_char_idx
cdef int search_chars_l
cdef const unsigned char[:] search_chars
cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0
@ -2121,7 +2118,10 @@ cdef void _search_for_chars(
this_tok_str_idx == tok_str_l or
(tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0].
):
ch_wdth = abs(this_tok_str_idx - last_tok_str_idx)
if this_tok_str_idx > last_tok_str_idx:
ch_wdth = this_tok_str_idx - last_tok_str_idx
else:
ch_wdth = last_tok_str_idx - this_tok_str_idx
if ch_wdth == 1:
search_chars = s_1byte_ch
elif ch_wdth == 2:
@ -2130,9 +2130,11 @@ cdef void _search_for_chars(
search_chars = s_3byte_ch
else:
search_chars = s_4byte_ch
search_chars_l = len(search_chars)
tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx
for search_char_idx in range(0, len(search_chars), ch_wdth):
search_char_idx = 0
while search_char_idx < search_chars_l:
cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth)
if cmp_result == 0:
memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth)
@ -2143,6 +2145,7 @@ cdef void _search_for_chars(
return
if cmp_result <= 0:
break
search_char_idx += ch_wdth
last_tok_str_idx = this_tok_str_idx
if suffs_not_prefs:
this_tok_str_idx -= 1