Tests passing again after refactoring

This commit is contained in:
richardpaulhudson 2022-10-28 13:31:14 +02:00
parent 5d151b4abe
commit 217ff36559
4 changed files with 254 additions and 185 deletions

View File

@@ -320,7 +320,7 @@ cdef class StringStore:
if hash_val == 0:
return ""
elif hash_val < len(SYMBOLS_BY_INT):
return SYMBOLS_BY_INT[hash_val]
return SYMBOLS_BY_INT[hash_val].encode("utf-8")
cdef Utf8Str* string = <Utf8Str*>self._map.get(hash_val)
cdef int i, length
if string.s[0] < sizeof(string.s) and string.s[0] != 0:

View File

@@ -14,7 +14,7 @@ from spacy.lang.xx import MultiLanguage
from spacy.language import Language
from spacy.lexeme import Lexeme
from spacy.tokens import Doc, Span, SpanGroup, Token
from spacy.util import get_arrays_for_search_chars
from spacy.util import get_search_char_byte_arrays
from spacy.vocab import Vocab
from .test_underscore import clean_underscore # noqa: F401
@@ -995,8 +995,7 @@ def test_doc_spans_setdefault(en_tokenizer):
def _get_unsigned_32_bit_hash(input: str) -> int:
input = input.replace(" ", "\x00")
working_hash = hash(input.encode("UTF-32LE"))
working_hash = hash(input.encode("UTF-8"))
if working_hash < 0:
working_hash = working_hash + (2 << 31)
return working_hash
@@ -1006,27 +1005,29 @@ def _get_unsigned_32_bit_hash(input: str) -> int:
def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive):
doc = en_tokenizer("spaCy✨ and Prodigy")
ops = get_current_ops()
pref_search, pref_lookup = get_arrays_for_search_chars("Rp", case_sensitive)
suff_search, suff_lookup = get_arrays_for_search_chars("xx✨rp", case_sensitive)
ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("Rp", case_sensitive)
ss1, ss2, ss3, ss4 = get_search_char_byte_arrays("xx✨rp", case_sensitive)
hashes = doc.get_character_combination_hashes(
cs=case_sensitive,
p_lengths=ops.asarray1i([1, 4, 3]),
p_lengths=ops.asarray1i([1, 3, 4]),
s_lengths=ops.asarray1i([2, 3, 4, 5]),
ps_search=pref_search,
ps_lookup=pref_lookup,
ps_l=2 if case_sensitive else 4,
ps_1byte_ch=ps1,
ps_2byte_ch=ps2,
ps_3byte_ch=ps3,
ps_4byte_ch=ps4,
ps_lengths=ops.asarray1i([2]),
ss_search=suff_search,
ss_lookup=suff_lookup,
ss_l=5 if case_sensitive else 9,
ss_lengths=ops.asarray1i([2, 1]),
ss_1byte_ch=ss1,
ss_2byte_ch=ss2,
ss_3byte_ch=ss3,
ss_4byte_ch=ss4,
ss_lengths=ops.asarray1i([1, 2]),
)
assert hashes[0][0] == _get_unsigned_32_bit_hash("s")
assert hashes[0][1] == _get_unsigned_32_bit_hash(
assert hashes[0][1] == _get_unsigned_32_bit_hash("spa")
assert hashes[0][2] == _get_unsigned_32_bit_hash(
"spaC" if case_sensitive else "spac"
)
assert hashes[0][2] == _get_unsigned_32_bit_hash("spa")
assert hashes[0][3] == _get_unsigned_32_bit_hash("Cy" if case_sensitive else "cy")
assert hashes[0][4] == _get_unsigned_32_bit_hash("aCy" if case_sensitive else "acy")
assert hashes[0][5] == _get_unsigned_32_bit_hash(
@@ -1046,7 +1047,7 @@ def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive
assert hashes[1][4] == _get_unsigned_32_bit_hash("")
assert hashes[1][5] == _get_unsigned_32_bit_hash("")
assert hashes[1][6] == _get_unsigned_32_bit_hash("")
assert hashes[1][7] == _get_unsigned_32_bit_hash(" ")
assert hashes[1][7] == 0
assert hashes[1][8] == _get_unsigned_32_bit_hash("")
assert hashes[1][9] == _get_unsigned_32_bit_hash("")
assert hashes[2][0] == _get_unsigned_32_bit_hash("a")
@@ -1056,54 +1057,57 @@ def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive
assert hashes[2][4] == _get_unsigned_32_bit_hash("and")
assert hashes[2][5] == _get_unsigned_32_bit_hash("and")
assert hashes[2][6] == _get_unsigned_32_bit_hash("and")
assert hashes[2][7] == _get_unsigned_32_bit_hash(" ")
assert hashes[2][8] == _get_unsigned_32_bit_hash(" ")
assert hashes[2][9] == _get_unsigned_32_bit_hash(" ")
assert hashes[2][7] == 0
assert hashes[2][8] == 0
assert hashes[2][9] == 0
assert hashes[3][0] == _get_unsigned_32_bit_hash("P" if case_sensitive else "p")
assert hashes[3][1] == _get_unsigned_32_bit_hash(
assert hashes[3][1] == _get_unsigned_32_bit_hash("Pro" if case_sensitive else "pro")
assert hashes[3][2] == _get_unsigned_32_bit_hash(
"Prod" if case_sensitive else "prod"
)
assert hashes[3][2] == _get_unsigned_32_bit_hash("Pro" if case_sensitive else "pro")
assert hashes[3][3] == _get_unsigned_32_bit_hash("gy")
assert hashes[3][4] == _get_unsigned_32_bit_hash("igy")
assert hashes[3][5] == _get_unsigned_32_bit_hash("digy")
assert hashes[3][6] == _get_unsigned_32_bit_hash("odigy")
assert hashes[3][7] == _get_unsigned_32_bit_hash(" " if case_sensitive else "pr")
assert hashes[3][7] == (0 if case_sensitive else _get_unsigned_32_bit_hash("pr"))
assert hashes[3][9] == _get_unsigned_32_bit_hash("r")
assert hashes[3][8] == _get_unsigned_32_bit_hash("r")
if case_sensitive:
assert hashes[3][8] == _get_unsigned_32_bit_hash("r ")
assert hashes[3][9] == _get_unsigned_32_bit_hash("r")
else:
assert hashes[3][8] == _get_unsigned_32_bit_hash("rp")
assert hashes[3][9] == _get_unsigned_32_bit_hash("rp")
# check values are the same cross-platform
if case_sensitive:
assert hashes[0][1] == 3712103410
assert hashes[0][2] == 3041529170
else:
assert hashes[0][1] == 307339932
assert hashes[1][3] == 2414314354
assert hashes[2][8] == 1669671676
assert hashes[0][2] == 2199614696
assert hashes[1][3] == 910783208
assert hashes[3][8] == 1553167345
def test_get_character_combination_hashes_good_case_partial(en_tokenizer):
doc = en_tokenizer("spaCy✨ and Prodigy")
ops = get_current_ops()
pref_search, pref_lookup = get_arrays_for_search_chars("rp", False)
ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("rp", False)
hashes = doc.get_character_combination_hashes(
cs=False,
p_lengths=ops.asarray1i([]),
s_lengths=ops.asarray1i([2, 3, 4, 5]),
ps_search=pref_search,
ps_lookup=pref_lookup,
ps_l=4,
ps_1byte_ch=ps1,
ps_2byte_ch=ps2,
ps_3byte_ch=ps3,
ps_4byte_ch=ps4,
ps_lengths=ops.asarray1i([2]),
ss_search=bytes(),
ss_lookup=bytes(),
ss_l=0,
ss_1byte_ch=bytes(),
ss_2byte_ch=bytes(),
ss_3byte_ch=bytes(),
ss_4byte_ch=bytes(),
ss_lengths=ops.asarray1i([]),
)
assert hashes[0][0] == _get_unsigned_32_bit_hash("cy")
assert hashes[0][1] == _get_unsigned_32_bit_hash("acy")
assert hashes[0][2] == _get_unsigned_32_bit_hash("pacy")
@@ -1113,12 +1117,12 @@ def test_get_character_combination_hashes_good_case_partial(en_tokenizer):
assert hashes[1][1] == _get_unsigned_32_bit_hash("")
assert hashes[1][2] == _get_unsigned_32_bit_hash("")
assert hashes[1][3] == _get_unsigned_32_bit_hash("")
assert hashes[1][4] == _get_unsigned_32_bit_hash(" ")
assert hashes[1][4] == 0
assert hashes[2][0] == _get_unsigned_32_bit_hash("nd")
assert hashes[2][1] == _get_unsigned_32_bit_hash("and")
assert hashes[2][2] == _get_unsigned_32_bit_hash("and")
assert hashes[2][3] == _get_unsigned_32_bit_hash("and")
assert hashes[2][4] == _get_unsigned_32_bit_hash(" ")
assert hashes[2][4] == 0
assert hashes[3][0] == _get_unsigned_32_bit_hash("gy")
assert hashes[3][1] == _get_unsigned_32_bit_hash("igy")
assert hashes[3][2] == _get_unsigned_32_bit_hash("digy")
@@ -1126,30 +1130,127 @@ def test_get_character_combination_hashes_good_case_partial(en_tokenizer):
assert hashes[3][4] == _get_unsigned_32_bit_hash("pr")
def test_get_character_combination_hashes_copying_in_middle(en_tokenizer):
doc = en_tokenizer("sp𐌞Cé")
ops = get_current_ops()
for p_length in range(1, 8):
for s_length in range(1, 8):
hashes = doc.get_character_combination_hashes(
cs=False,
p_lengths=ops.asarray1i([p_length]),
s_lengths=ops.asarray1i([s_length]),
ps_search=bytes(),
ps_lookup=bytes(),
ps_l=0,
ps_1byte_ch=bytes(),
ps_2byte_ch=bytes(),
ps_3byte_ch=bytes(),
ps_4byte_ch=bytes(),
ps_lengths=ops.asarray1i([]),
ss_search=bytes(),
ss_lookup=bytes(),
ss_l=0,
ss_1byte_ch=bytes(),
ss_2byte_ch=bytes(),
ss_3byte_ch=bytes(),
ss_4byte_ch=bytes(),
ss_lengths=ops.asarray1i([]),
)
assert hashes[0][0] == _get_unsigned_32_bit_hash("sp𐌞cé"[:p_length])
assert hashes[0][1] == _get_unsigned_32_bit_hash(" sp𐌞cé"[8 - s_length :])
assert hashes[0][1] == _get_unsigned_32_bit_hash("sp𐌞cé"[-s_length:])
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_character_combination_hashes_turkish_i_with_dot(en_tokenizer, case_sensitive):
doc = en_tokenizer("İ".lower() + "İ")
ops = get_current_ops()
s1, s2, s3, s4 = get_search_char_byte_arrays("İ", case_sensitive)
hashes = doc.get_character_combination_hashes(
cs=case_sensitive,
p_lengths=ops.asarray1i([1, 2, 3, 4]),
s_lengths=ops.asarray1i([1, 2, 3, 4]),
ps_1byte_ch=s1,
ps_2byte_ch=s2,
ps_3byte_ch=s3,
ps_4byte_ch=s4,
ps_lengths=ops.asarray1i([1, 2, 3, 4]),
ss_1byte_ch=s1,
ss_2byte_ch=s2,
ss_3byte_ch=s3,
ss_4byte_ch=s4,
ss_lengths=ops.asarray1i([1, 2, 3, 4]),
)
COMBINING_DOT_ABOVE = b"\xcc\x87".decode("UTF-8")
assert hashes[0][0] == _get_unsigned_32_bit_hash("i")
assert hashes[0][1] == _get_unsigned_32_bit_hash("İ".lower())
if case_sensitive:
assert hashes[0][2] == _get_unsigned_32_bit_hash("İ".lower() + "İ")
assert hashes[0][3] == _get_unsigned_32_bit_hash("İ".lower() + "İ")
assert hashes[0][4] == _get_unsigned_32_bit_hash("İ")
assert hashes[0][5] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE + "İ")
assert hashes[0][6] == _get_unsigned_32_bit_hash("İ".lower() + "İ")
assert hashes[0][7] == _get_unsigned_32_bit_hash("İ".lower() + "İ")
assert hashes[0][8] == _get_unsigned_32_bit_hash("İ")
assert hashes[0][9] == _get_unsigned_32_bit_hash("İ")
assert hashes[0][12] == _get_unsigned_32_bit_hash("İ")
assert hashes[0][13] == _get_unsigned_32_bit_hash("İ")
else:
assert hashes[0][2] == _get_unsigned_32_bit_hash("İ".lower() + "i")
assert hashes[0][3] == _get_unsigned_32_bit_hash("İ".lower() * 2)
assert hashes[0][4] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE)
assert hashes[0][5] == _get_unsigned_32_bit_hash("İ".lower())
assert hashes[0][6] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE + "İ".lower())
assert hashes[0][7] == _get_unsigned_32_bit_hash("İ".lower() * 2)
assert hashes[0][8] == _get_unsigned_32_bit_hash("i")
assert hashes[0][9] == _get_unsigned_32_bit_hash("İ".lower())
assert hashes[0][10] == _get_unsigned_32_bit_hash("İ".lower() + "i")
assert hashes[0][11] == _get_unsigned_32_bit_hash("İ".lower() * 2)
assert hashes[0][12] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE)
assert hashes[0][13] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE + "i")
assert hashes[0][14] == _get_unsigned_32_bit_hash(COMBINING_DOT_ABOVE + "i" + COMBINING_DOT_ABOVE)
assert hashes[0][15] == _get_unsigned_32_bit_hash((COMBINING_DOT_ABOVE + "i") * 2)
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_character_combination_hashes_string_store_spec_cases(en_tokenizer, case_sensitive):
symbol = "FLAG19"
short_word = "bee"
normal_word = "serendipity"
long_word = "serendipity" * 50
assert len(long_word) > 255
doc = en_tokenizer(' '.join((symbol, short_word, normal_word, long_word)))
assert len(doc) == 4
ops = get_current_ops()
ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("E", case_sensitive)
hashes = doc.get_character_combination_hashes(
cs=case_sensitive,
p_lengths=ops.asarray1i([2]),
s_lengths=ops.asarray1i([2]),
ps_1byte_ch=ps1,
ps_2byte_ch=ps2,
ps_3byte_ch=ps3,
ps_4byte_ch=ps4,
ps_lengths=ops.asarray1i([2]),
ss_1byte_ch=bytes(),
ss_2byte_ch=bytes(),
ss_3byte_ch=bytes(),
ss_4byte_ch=bytes(),
ss_lengths=ops.asarray1i([]),
)
assert hashes[0][0] == _get_unsigned_32_bit_hash("FL" if case_sensitive else "fl")
assert hashes[0][1] == _get_unsigned_32_bit_hash("19")
assert hashes[0][2] == 0
assert hashes[1][0] == _get_unsigned_32_bit_hash("be")
assert hashes[1][1] == _get_unsigned_32_bit_hash("ee")
if case_sensitive:
assert hashes[1][2] == 0
else:
assert hashes[1][2] == _get_unsigned_32_bit_hash("ee")
assert hashes[2][0] == hashes[3][0] == _get_unsigned_32_bit_hash("se")
assert hashes[2][1] == hashes[3][1] == _get_unsigned_32_bit_hash("ty")
if case_sensitive:
assert hashes[2][2] == hashes[3][2] == 0
else:
assert hashes[2][2] == hashes[3][2] == _get_unsigned_32_bit_hash("ee")
def test_character_combination_hashes_empty_lengths(en_tokenizer):
@@ -1159,12 +1260,14 @@ def test_character_combination_hashes_empty_lengths(en_tokenizer):
cs=True,
p_lengths=ops.asarray1i([]),
s_lengths=ops.asarray1i([]),
ps_search=bytes(),
ps_lookup=bytes(),
ps_l=0,
ps_1byte_ch=bytes(),
ps_2byte_ch=bytes(),
ps_3byte_ch=bytes(),
ps_4byte_ch=bytes(),
ps_lengths=ops.asarray1i([]),
ss_search=bytes(),
ss_lookup=bytes(),
ss_l=0,
ss_1byte_ch=bytes(),
ss_2byte_ch=bytes(),
ss_3byte_ch=bytes(),
ss_4byte_ch=bytes(),
ss_lengths=ops.asarray1i([]),
).shape == (1, 0)

View File

@@ -40,13 +40,13 @@ cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
cdef void _set_affix_lengths(
const unsigned char[:] tok_str,
unsigned char* aff_len_buf,
const int pref_len,
const int suff_len,
) nogil
unsigned char* aff_l_buf,
const int pref_l,
const int suff_l,
)
ccdef void _search_for_chars(
cdef void _search_for_chars(
const unsigned char[:] tok_str,
const unsigned char[:] s_1byte_ch,
const unsigned char[:] s_2byte_ch,
@@ -54,9 +54,9 @@ ccdef void _search_for_chars(
const unsigned char[:] s_4byte_ch,
unsigned char* res_buf,
int max_res_l,
unsigned char* len_buf,
unsigned char* l_buf,
bint suffs_not_prefs
) nogil
)
cdef class Doc:

View File

@@ -1736,7 +1736,7 @@ cdef class Doc:
return output
def np.ndarray get_character_combination_hashes(self,
def get_character_combination_hashes(self,
*,
const bint cs,
np.ndarray p_lengths,
@@ -1751,7 +1751,7 @@ cdef class Doc:
const unsigned char[:] ss_3byte_ch,
const unsigned char[:] ss_4byte_ch,
np.ndarray ss_lengths,
) nogil:
):
"""
Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
derived from the raw text of each token.
@@ -1797,11 +1797,11 @@ cdef class Doc:
# Define / allocate buffers
cdef int aff_l = p_max_l + s_max_l
cdef char* aff_len_buf = self.mem.alloc(aff_l, 1)
cdef char* ps_res_buf = self.mem.alloc(ps_max_l, 4)
cdef char* ps_len_buf = self.mem.alloc(ps_max_l, 1)
cdef char* ss_res_buf = self.mem.alloc(ss_max_l, 4)
cdef char* ss_len_buf = self.mem.alloc(ss_max_l, 1)
cdef unsigned char* aff_l_buf = <unsigned char*> self.mem.alloc(aff_l, 1)
cdef unsigned char* ps_res_buf = <unsigned char*> self.mem.alloc(ps_max_l, 4)
cdef unsigned char* ps_l_buf = <unsigned char*> self.mem.alloc(ps_max_l, 1)
cdef unsigned char* ss_res_buf = <unsigned char*> self.mem.alloc(ss_max_l, 4)
cdef unsigned char* ss_l_buf = <unsigned char*> self.mem.alloc(ss_max_l, 1)
# Define memory views on length arrays
cdef int[:] p_lengths_v = p_lengths
@@ -1812,7 +1812,7 @@ cdef class Doc:
# Define working variables
cdef TokenC tok_c
cdef int tok_i, offset
cdef uint64_t hash_val
cdef uint64_t hash_val = 0
cdef attr_t num_tok_attr
cdef const unsigned char[:] tok_str
@@ -1822,43 +1822,44 @@ cdef class Doc:
tok_str = self.vocab.strings.utf8_view(num_tok_attr)
if aff_l > 0:
_set_affix_lengths(tok_str, aff_len_buf, p_max_l, s_max_l)
_set_affix_lengths(tok_str, aff_l_buf, p_max_l, s_max_l)
for hash_idx in range(p_h_num):
offset = aff_len_buf[p_lengths_v[hash_idx]]
offset = aff_l_buf[p_lengths_v[hash_idx] - 1]
if offset > 0:
hash_val = hash32(<void*> &tok_str[0], offset, 0)
hashes[tok_i, hash_idx] = hash_val
for hash_idx in range(p_h_num, s_h_end):
offset = s_lengths_v[hash_idx - p_h_num]
offset = aff_l_buf[s_lengths_v[hash_idx - p_h_num] + p_max_l - 1]
if offset > 0:
hash_val = hash32(<void*> &tok_str[len(tok_str) - offset], offset, 0)
hashes[tok_i, hash_idx] = hash_val
if ps_h_num > 0:
_search_for_chars(tok_str, ps_1byte_ch, ps_2byte_ch, ps_3byte_ch, ps_4byte_ch, ps_res_buf, ps_max_l, ps_res_len, False)
_search_for_chars(tok_str, ps_1byte_ch, ps_2byte_ch, ps_3byte_ch, ps_4byte_ch, ps_res_buf, ps_max_l, ps_l_buf, False)
hash_val = 0
for hash_idx in range(s_h_end, ps_h_end):
offset = ps_lengths_v[hash_idx - s_h_end]
offset = ps_l_buf[ps_lengths_v[hash_idx - s_h_end] - 1]
if offset > 0:
hash_val = hash32(ps_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val
if ss_h_num > 0:
_search_for_chars(tok_str, ss_1byte_ch, ss_2byte_ch, ss_3byte_ch, ss_4byte_ch, ss_res_buf, ss_max_l, ss_res_len, True)
_search_for_chars(tok_str, ss_1byte_ch, ss_2byte_ch, ss_3byte_ch, ss_4byte_ch, ss_res_buf, ss_max_l, ss_l_buf, True)
hash_val = 0
for hash_idx in range(ps_h_end, ss_h_end):
offset = ss_lengths_v[hash_idx - ps_h_end]
offset = ss_l_buf[ss_lengths_v[hash_idx - ps_h_end] - 1]
if offset > 0:
hash_val = hash32(ss_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val
self.mem.free(aff_len_buf)
self.mem.free(aff_l_buf)
self.mem.free(ps_res_buf)
self.mem.free(ps_len_buf)
self.mem.free(ps_l_buf)
self.mem.free(ss_res_buf)
self.mem.free(ss_len_buf)
self.mem.free(ss_l_buf)
return hashes
@staticmethod
@@ -2044,46 +2045,45 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
cdef void _set_affix_lengths(
const unsigned char[:] tok_str,
unsigned char* aff_len_buf,
const int pref_len,
const int suff_len,
) nogil:
""" TODO : Populate *len_buf*, which has length *pref_len+suff_len* with the byte lengths of the first *pref_len* and the last
*suff_len* characters within *tok_str*. If the word is shorter than pref and/or suff, the empty lengths in the middle are
filled with zeros.
unsigned char* aff_l_buf,
const int pref_l,
const int suff_l,
):
""" Populate *aff_l_buf*, which has length *pref_l+suff_l* with the byte lengths of the first *pref_l* and the last
*suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are
populated with the byte length of the whole word.
tok_str: a memoryview of a UTF-8 representation of a string.
len_buf: a buffer of length *pref_len+suff_len* in which to store the lengths. The calling code ensures that lengths
aff_l_buf: a buffer of length *pref_l+suff_l* in which to store the lengths. The calling code ensures that lengths
greater than 255 cannot occur.
pref_len: the number of characters to process at the beginning of the word.
suff_len: the number of characters to process at the end of the word.
pref_l: the number of characters to process at the beginning of the word.
suff_l: the number of characters to process at the end of the word.
"""
cdef int tok_str_idx = 0, aff_len_buf_idx = 0, tok_str_len = len(tok_str)
cdef int tok_str_idx = 1, aff_l_buf_idx = 0, tok_str_l = len(tok_str)
while aff_len_buf_idx < pref_len:
if (tok_str[tok_str_idx] & 0xc0) != 0x80: # not a continuation character
aff_len_buf[aff_len_buf_idx] = tok_str_idx + 1
aff_len_buf_idx += 1
while aff_l_buf_idx < pref_l:
if tok_str_idx == len(tok_str) or ((tok_str[tok_str_idx] & 0xc0) != 0x80): # not a continuation character
aff_l_buf[aff_l_buf_idx] = tok_str_idx
aff_l_buf_idx += 1
tok_str_idx += 1
if tok_str_idx == len(tok_str):
if tok_str_idx > len(tok_str):
break
if aff_len_buf_idx < pref_len:
memset(aff_len_buf + aff_len_buf_idx, 0, pref_len - aff_len_buf_idx)
aff_len_buf_idx = pref_len
if aff_l_buf_idx < pref_l:
memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l - aff_l_buf_idx)
aff_l_buf_idx = pref_l
tok_str_idx = 1
while aff_len_buf_idx < pref_len + suff_len:
tok_str_idx = tok_str_l - 1
while aff_l_buf_idx < pref_l + suff_l:
if (tok_str[tok_str_idx] & 0xc0) != 0x80: # not a continuation character
aff_len_buf[aff_len_buf_idx] = tok_str_len - tok_str_idx
aff_len_buf_idx += 1
tok_str_idx += 1
if tok_str_idx > tok_str_len:
aff_l_buf[aff_l_buf_idx] = tok_str_l - tok_str_idx
aff_l_buf_idx += 1
tok_str_idx -= 1
if tok_str_idx < 0:
break
if aff_len_buf_idx < pref_len + suff_len:
memset(aff_len_buf + aff_len_buf_idx, 0, suff_len - aff_len_buf_idx)
if aff_l_buf_idx < pref_l + suff_l:
memset(aff_l_buf + aff_l_buf_idx, aff_l_buf[aff_l_buf_idx - 1], pref_l + suff_l - aff_l_buf_idx)
cdef void _search_for_chars(
const unsigned char[:] tok_str,
@@ -2093,31 +2093,33 @@ cdef void _search_for_chars(
const unsigned char[:] s_4byte_ch,
unsigned char* res_buf,
int max_res_l,
unsigned char* len_buf,
unsigned char* l_buf,
bint suffs_not_prefs
) nogil:
):
""" Search *tok_str* within a string for characters within the *s_<n>byte_ch> buffers, starting at the
beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches,
it is added to *res_buf* and the byte length up to that point is added to *len_buf*.
it is added to *res_buf* and the byte length up to that point is added to *l_buf*. When nothing
more is found, the remainder of *l_buf* is populated with the byte length from the last result,
which may be *0* if the search was not successful.
tok_str: a memoryview of a UTF-8 representation of a string.
s_<n>byte_ch: a byte array containing, in order, the n-byte-wide characters to search for.
res_buf: the buffer in which to place the search results.
max_res_l: the maximum number of found characters to place in *res_buf*.
len_buf: a buffer of length *max_res_l* in which to store the byte lengths.
l_buf: a buffer of length *max_res_l* in which to store the byte lengths.
The calling code ensures that lengths greater than 255 cannot occur.
suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning.
"""
cdef int tok_str_len = len(tok_str), search_char_idx = 0, res_buf_idx = 0, len_buf_idx = 0
cdef int last_tok_str_idx = tok_str_len if suffs_not_prefs else 0
cdef int this_tok_str_idx = tok_str_len - 1 if suffs_not_prefs else 1
cdef int ch_wdth, tok_start_idx
cdef char[:] search_chars
cdef int tok_str_l = len(tok_str), search_char_idx = 0, res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx
cdef const unsigned char[:] search_chars
cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0
cdef int this_tok_str_idx = tok_str_l - 1 if suffs_not_prefs else 1
while True:
if (
this_tok_str_idx == tok_str_len or
(tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character
this_tok_str_idx == tok_str_l or
(tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0].
):
ch_wdth = abs(this_tok_str_idx - last_tok_str_idx)
if ch_wdth == 1:
@@ -2129,16 +2131,17 @@ cdef void _search_for_chars(
else:
search_chars = s_4byte_ch
tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx
for search_char_idx in range(0, len(search_chars), ch_wdth):
cmp_result = memcmp(tok_str + tok_start_idx, search_chars + search_char_idx, ch_wdth)
cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth)
if cmp_result == 0:
memcpy(res_buf + res_buf_idx, search_chars + search_char_idx, ch_wdth)
memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth)
res_buf_idx += ch_wdth
len_buf[len_buf_idx] = res_buf_idx
len_buf_idx += 1
if len_buf_idx == max_res_l:
l_buf[l_buf_idx] = res_buf_idx
l_buf_idx += 1
if l_buf_idx == max_res_l:
return
if cmp_result >= 0:
if cmp_result <= 0:
break
last_tok_str_idx = this_tok_str_idx
if suffs_not_prefs:
@@ -2147,48 +2150,11 @@ cdef void _search_for_chars(
break
else:
this_tok_str_idx += 1
if this_tok_str_idx >= tok_str_len:
if this_tok_str_idx > tok_str_l:
break
# fill in unused characters in the length buffer with 0
memset(res_buf + res_buf_idx, 0, max_res_l - res_buf_idx)
cdef int result_buf_idx = 0, text_string_idx = tok_idx + (tok_len - 1) if suffs_not_prefs else tok_idx
cdef int search_buf_idx
cdef int cmp_result
while result_buf_idx < result_buf_len:
for search_buf_idx in range (search_buf_len):
cmp_result = memcmp(search_buf + search_buf_idx, text_buf + text_string_idx, sizeof(Py_UCS4))
if cmp_result == 0:
memcpy(result_buf + result_buf_idx, lookup_buf + search_buf_idx, sizeof(Py_UCS4))
result_buf_idx += 1
if cmp_result >= 0:
break
if suffs_not_prefs:
if text_string_idx <= tok_idx:
break
text_string_idx -= 1
else:
text_string_idx += 1
if text_string_idx >= tok_idx + tok_len:
break
# fill in any unused characters in the result buffer with zeros
if result_buf_idx < result_buf_len:
memset(result_buf + result_buf_idx, 0, (result_buf_len - result_buf_idx) * sizeof(Py_UCS4))
return result_buf_idx > 0
# fill in unused entries in the length buffer with the byte length of the last result
memset(l_buf + l_buf_idx, res_buf_idx, max_res_l - l_buf_idx)
def pickle_doc(doc):