Performance improvements

This commit is contained in:
richardpaulhudson 2022-10-05 14:17:28 +02:00
parent d296ae9d8e
commit f712e0bc4a
4 changed files with 82 additions and 33 deletions

View File

@ -977,26 +977,26 @@ def test_doc_spans_setdefault(en_tokenizer):
doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
assert len(doc.spans["key3"]) == 2
@pytest.mark.parametrize(
"case_sensitive", [True, False]
)
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
def _get_unsigned_32_bit_hash(input:str) -> int:
if not case_sensitive:
input = input.lower()
def _get_unsigned_32_bit_hash(input: str) -> int:
working_hash = hash(input.encode("UTF-16")[2:])
if working_hash < 0:
working_hash = working_hash + (2<<31)
working_hash = working_hash + (2 << 31)
return working_hash
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
doc = en_tokenizer("spaCy✨ and Prodigy")
prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rp", 1, 3)
assert prefixes[0][0] == _get_unsigned_32_bit_hash("s")
assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp")
assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
assert prefixes[0][3] == _get_unsigned_32_bit_hash("spaC")
assert prefixes[0][3] == _get_unsigned_32_bit_hash(
"spaC" if case_sensitive else "spac"
)
assert prefixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[1][0] == _get_unsigned_32_bit_hash("")
assert prefixes[1][1] == _get_unsigned_32_bit_hash("")
@ -1008,18 +1008,28 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
assert prefixes[2][2] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][3] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[3][0] == _get_unsigned_32_bit_hash("P")
assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr")
assert prefixes[3][2] == _get_unsigned_32_bit_hash("Pro")
assert prefixes[3][3] == _get_unsigned_32_bit_hash("Prod")
assert prefixes[3][0] == _get_unsigned_32_bit_hash("P" if case_sensitive else "p")
assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr" if case_sensitive else "pr")
assert prefixes[3][2] == _get_unsigned_32_bit_hash(
"Pro" if case_sensitive else "pro"
)
assert prefixes[3][3] == _get_unsigned_32_bit_hash(
"Prod" if case_sensitive else "prod"
)
assert prefixes[3][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy")
assert suffixes[0][1] == _get_unsigned_32_bit_hash("aCy")
assert suffixes[0][2] == _get_unsigned_32_bit_hash("paCy")
assert suffixes[0][3] == _get_unsigned_32_bit_hash("spaCy")
assert suffixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][5] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy" if case_sensitive else "cy")
assert suffixes[0][1] == _get_unsigned_32_bit_hash(
"aCy" if case_sensitive else "acy"
)
assert suffixes[0][2] == _get_unsigned_32_bit_hash(
"paCy" if case_sensitive else "pacy"
)
assert suffixes[0][3] == _get_unsigned_32_bit_hash(
"spaCy" if case_sensitive else "spacy"
)
assert suffixes[0][4] == _get_unsigned_32_bit_hash("p")
assert suffixes[0][5] == _get_unsigned_32_bit_hash("p ")
assert suffixes[1][0] == _get_unsigned_32_bit_hash("")
assert suffixes[1][1] == _get_unsigned_32_bit_hash("")
assert suffixes[1][2] == _get_unsigned_32_bit_hash("")
@ -1039,12 +1049,35 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
assert suffixes[3][4] == _get_unsigned_32_bit_hash("r")
if case_sensitive:
assert suffixes[3][5] == _get_unsigned_32_bit_hash("rP")
else:
assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ")
else:
assert suffixes[3][5] == _get_unsigned_32_bit_hash("rp")
# check values are the same cross-platform
assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016
assert prefixes[0][3] == 753329845 if case_sensitive else 18446744071614199016
assert suffixes[1][0] == 3425774424
assert suffixes[2][5] == 3076404432
def test_get_affix_hashes_4_byte_char_at_end(en_tokenizer):
doc = en_tokenizer("and𐌞")
suffixes = doc.get_affix_hashes(True, True, 1, 4, "a", 1, 2)
assert suffixes[0][1] == _get_unsigned_32_bit_hash("𐌞")
assert suffixes[0][2] == _get_unsigned_32_bit_hash("d𐌞")
assert suffixes[0][3] == _get_unsigned_32_bit_hash("a")
def test_get_affix_hashes_4_byte_char_in_middle(en_tokenizer):
doc = en_tokenizer("and𐌞a")
suffixes = doc.get_affix_hashes(True, False, 1, 5, "a", 1, 3)
assert suffixes[0][0] == _get_unsigned_32_bit_hash("a")
assert suffixes[0][2] == _get_unsigned_32_bit_hash("𐌞a")
assert suffixes[0][3] == _get_unsigned_32_bit_hash("d𐌞a")
assert suffixes[0][4] == _get_unsigned_32_bit_hash("a")
assert suffixes[0][5] == _get_unsigned_32_bit_hash("aa")
def test_get_affixes_4_byte_special_char(en_tokenizer):
doc = en_tokenizer("and𐌞")
with pytest.raises(ValueError):
doc.get_affix_hashes(True, True, 2, 6, "𐌞", 2, 3)

View File

@ -39,7 +39,14 @@ cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, bint check
cdef bint _is_utf16_char_in_scs(unsigned short utf16_char, const unsigned char[:] scs)
cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, bint suffs_not_prefs)
cdef void _set_scs_buffer(
const unsigned char[:] searched_string,
const unsigned int ss_len,
const unsigned char[:] scs,
char* buf,
const unsigned int buf_len,
const bint suffs_not_prefs
)
cdef class Doc:

View File

@ -176,3 +176,5 @@ class Doc:
def to_utf8_array(self, nr_char: int = ...) -> Ints2d: ...
@staticmethod
def _get_array_attrs() -> Tuple[Any]: ...
def get_affix_hashes(self, suffs_not_prefs: bool, case_sensitive: bool, len_start: int, len_end: int,
special_chars: str, sc_len_start: int, sc_len_end: int) -> Ints2d: ...

View File

@ -1749,6 +1749,7 @@ cdef class Doc:
cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64")
cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:] # first two bytes express endianness and are not relevant here
cdef char* scs_buffer = scs_buffer_bytes
cdef unsigned int buf_len = len(scs_buffer_bytes)
cdef attr_t num_tok_attr
cdef str str_tok_attr
@ -1768,7 +1769,7 @@ cdef class Doc:
working_start = 0
output[tok_ind, norm_hash_ind] = hash32(<void*> &tok_str[working_start], working_len, 0)
_set_scs_buffer(tok_str, scs, scs_buffer, suffs_not_prefs)
_set_scs_buffer(tok_str, len_tok_str, scs, scs_buffer, buf_len, suffs_not_prefs)
for spec_hash_ind in range(num_spec_hashes):
working_len = (sc_len_start + spec_hash_ind) * 2
output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0)
@ -1982,18 +1983,24 @@ cdef bint _is_utf16_char_in_scs(const unsigned short utf16_char, const unsigned
return False
cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, const bint suffs_not_prefs):
cdef void _set_scs_buffer(
const unsigned char[:] searched_string,
const unsigned int ss_len,
const unsigned char[:] scs,
char* buf,
const unsigned int buf_len,
const bint suffs_not_prefs
):
""" Pick the UFT-16 characters from *searched_string* that are also in *scs* and writes them in order to *buf*.
If *suffs_not_prefs*, the search starts from the end of *searched_string* rather than from the beginning.
"""
cdef unsigned int buf_len = len(buf), buf_idx = 0
cdef unsigned int ss_len = len(searched_string), ss_idx = ss_len - 2 if suffs_not_prefs else 0
cdef unsigned int buf_idx = 0, ss_idx = ss_len - 2 if suffs_not_prefs else 0
cdef unsigned short working_utf16_char, SPACE = 32
while buf_idx < buf_len:
working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0]
if _is_utf16_char_in_scs(working_utf16_char, scs):
memcpy(buf, &working_utf16_char, 2)
memcpy(buf + buf_idx, &working_utf16_char, 2)
buf_idx += 2
if suffs_not_prefs:
if ss_idx == 0:
@ -2005,7 +2012,7 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
break
while buf_idx < buf_len:
memcpy(buf, &SPACE, 2)
memcpy(buf + buf_idx, &SPACE, 2)
buf_idx += 2
def pickle_doc(doc):