mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-02 19:30:19 +03:00
Performance improvements
This commit is contained in:
parent
d296ae9d8e
commit
f712e0bc4a
|
@ -977,26 +977,26 @@ def test_doc_spans_setdefault(en_tokenizer):
|
|||
doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
|
||||
assert len(doc.spans["key3"]) == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case_sensitive", [True, False]
|
||||
)
|
||||
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
|
||||
|
||||
def _get_unsigned_32_bit_hash(input:str) -> int:
|
||||
if not case_sensitive:
|
||||
input = input.lower()
|
||||
def _get_unsigned_32_bit_hash(input: str) -> int:
|
||||
working_hash = hash(input.encode("UTF-16")[2:])
|
||||
if working_hash < 0:
|
||||
working_hash = working_hash + (2<<31)
|
||||
working_hash = working_hash + (2 << 31)
|
||||
return working_hash
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
|
||||
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
|
||||
|
||||
doc = en_tokenizer("spaCy✨ and Prodigy")
|
||||
prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
|
||||
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
|
||||
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rp", 1, 3)
|
||||
assert prefixes[0][0] == _get_unsigned_32_bit_hash("s")
|
||||
assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp")
|
||||
assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
|
||||
assert prefixes[0][3] == _get_unsigned_32_bit_hash("spaC")
|
||||
assert prefixes[0][3] == _get_unsigned_32_bit_hash(
|
||||
"spaC" if case_sensitive else "spac"
|
||||
)
|
||||
assert prefixes[0][4] == _get_unsigned_32_bit_hash(" ")
|
||||
assert prefixes[1][0] == _get_unsigned_32_bit_hash("✨")
|
||||
assert prefixes[1][1] == _get_unsigned_32_bit_hash("✨")
|
||||
|
@ -1008,18 +1008,28 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
|
|||
assert prefixes[2][2] == _get_unsigned_32_bit_hash("and")
|
||||
assert prefixes[2][3] == _get_unsigned_32_bit_hash("and")
|
||||
assert prefixes[2][4] == _get_unsigned_32_bit_hash(" ")
|
||||
assert prefixes[3][0] == _get_unsigned_32_bit_hash("P")
|
||||
assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr")
|
||||
assert prefixes[3][2] == _get_unsigned_32_bit_hash("Pro")
|
||||
assert prefixes[3][3] == _get_unsigned_32_bit_hash("Prod")
|
||||
assert prefixes[3][0] == _get_unsigned_32_bit_hash("P" if case_sensitive else "p")
|
||||
assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr" if case_sensitive else "pr")
|
||||
assert prefixes[3][2] == _get_unsigned_32_bit_hash(
|
||||
"Pro" if case_sensitive else "pro"
|
||||
)
|
||||
assert prefixes[3][3] == _get_unsigned_32_bit_hash(
|
||||
"Prod" if case_sensitive else "prod"
|
||||
)
|
||||
assert prefixes[3][4] == _get_unsigned_32_bit_hash(" ")
|
||||
|
||||
assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy")
|
||||
assert suffixes[0][1] == _get_unsigned_32_bit_hash("aCy")
|
||||
assert suffixes[0][2] == _get_unsigned_32_bit_hash("paCy")
|
||||
assert suffixes[0][3] == _get_unsigned_32_bit_hash("spaCy")
|
||||
assert suffixes[0][4] == _get_unsigned_32_bit_hash(" ")
|
||||
assert suffixes[0][5] == _get_unsigned_32_bit_hash(" ")
|
||||
assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy" if case_sensitive else "cy")
|
||||
assert suffixes[0][1] == _get_unsigned_32_bit_hash(
|
||||
"aCy" if case_sensitive else "acy"
|
||||
)
|
||||
assert suffixes[0][2] == _get_unsigned_32_bit_hash(
|
||||
"paCy" if case_sensitive else "pacy"
|
||||
)
|
||||
assert suffixes[0][3] == _get_unsigned_32_bit_hash(
|
||||
"spaCy" if case_sensitive else "spacy"
|
||||
)
|
||||
assert suffixes[0][4] == _get_unsigned_32_bit_hash("p")
|
||||
assert suffixes[0][5] == _get_unsigned_32_bit_hash("p ")
|
||||
assert suffixes[1][0] == _get_unsigned_32_bit_hash("✨")
|
||||
assert suffixes[1][1] == _get_unsigned_32_bit_hash("✨")
|
||||
assert suffixes[1][2] == _get_unsigned_32_bit_hash("✨")
|
||||
|
@ -1039,12 +1049,35 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
|
|||
assert suffixes[3][4] == _get_unsigned_32_bit_hash("r")
|
||||
|
||||
if case_sensitive:
|
||||
assert suffixes[3][5] == _get_unsigned_32_bit_hash("rP")
|
||||
else:
|
||||
assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ")
|
||||
else:
|
||||
assert suffixes[3][5] == _get_unsigned_32_bit_hash("rp")
|
||||
|
||||
# check values are the same cross-platform
|
||||
assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016
|
||||
assert prefixes[0][3] == 753329845 if case_sensitive else 18446744071614199016
|
||||
assert suffixes[1][0] == 3425774424
|
||||
assert suffixes[2][5] == 3076404432
|
||||
|
||||
|
||||
def test_get_affix_hashes_4_byte_char_at_end(en_tokenizer):
|
||||
doc = en_tokenizer("and𐌞")
|
||||
suffixes = doc.get_affix_hashes(True, True, 1, 4, "a", 1, 2)
|
||||
assert suffixes[0][1] == _get_unsigned_32_bit_hash("𐌞")
|
||||
assert suffixes[0][2] == _get_unsigned_32_bit_hash("d𐌞")
|
||||
assert suffixes[0][3] == _get_unsigned_32_bit_hash("a")
|
||||
|
||||
|
||||
def test_get_affix_hashes_4_byte_char_in_middle(en_tokenizer):
|
||||
doc = en_tokenizer("and𐌞a")
|
||||
suffixes = doc.get_affix_hashes(True, False, 1, 5, "a", 1, 3)
|
||||
assert suffixes[0][0] == _get_unsigned_32_bit_hash("a")
|
||||
assert suffixes[0][2] == _get_unsigned_32_bit_hash("𐌞a")
|
||||
assert suffixes[0][3] == _get_unsigned_32_bit_hash("d𐌞a")
|
||||
assert suffixes[0][4] == _get_unsigned_32_bit_hash("a")
|
||||
assert suffixes[0][5] == _get_unsigned_32_bit_hash("aa")
|
||||
|
||||
|
||||
def test_get_affixes_4_byte_special_char(en_tokenizer):
|
||||
doc = en_tokenizer("and𐌞")
|
||||
with pytest.raises(ValueError):
|
||||
doc.get_affix_hashes(True, True, 2, 6, "𐌞", 2, 3)
|
||||
|
|
|
@ -39,7 +39,14 @@ cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, bint check
|
|||
cdef bint _is_utf16_char_in_scs(unsigned short utf16_char, const unsigned char[:] scs)
|
||||
|
||||
|
||||
cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, bint suffs_not_prefs)
|
||||
cdef void _set_scs_buffer(
|
||||
const unsigned char[:] searched_string,
|
||||
const unsigned int ss_len,
|
||||
const unsigned char[:] scs,
|
||||
char* buf,
|
||||
const unsigned int buf_len,
|
||||
const bint suffs_not_prefs
|
||||
)
|
||||
|
||||
|
||||
cdef class Doc:
|
||||
|
|
|
@ -176,3 +176,5 @@ class Doc:
|
|||
def to_utf8_array(self, nr_char: int = ...) -> Ints2d: ...
|
||||
@staticmethod
|
||||
def _get_array_attrs() -> Tuple[Any]: ...
|
||||
def get_affix_hashes(self, suffs_not_prefs: bool, case_sensitive: bool, len_start: int, len_end: int,
|
||||
special_chars: str, sc_len_start: int, sc_len_end: int) -> Ints2d: ...
|
||||
|
|
|
@ -1749,6 +1749,7 @@ cdef class Doc:
|
|||
cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64")
|
||||
cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:] # first two bytes express endianness and are not relevant here
|
||||
cdef char* scs_buffer = scs_buffer_bytes
|
||||
cdef unsigned int buf_len = len(scs_buffer_bytes)
|
||||
cdef attr_t num_tok_attr
|
||||
cdef str str_tok_attr
|
||||
|
||||
|
@ -1768,7 +1769,7 @@ cdef class Doc:
|
|||
working_start = 0
|
||||
output[tok_ind, norm_hash_ind] = hash32(<void*> &tok_str[working_start], working_len, 0)
|
||||
|
||||
_set_scs_buffer(tok_str, scs, scs_buffer, suffs_not_prefs)
|
||||
_set_scs_buffer(tok_str, len_tok_str, scs, scs_buffer, buf_len, suffs_not_prefs)
|
||||
for spec_hash_ind in range(num_spec_hashes):
|
||||
working_len = (sc_len_start + spec_hash_ind) * 2
|
||||
output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0)
|
||||
|
@ -1982,18 +1983,24 @@ cdef bint _is_utf16_char_in_scs(const unsigned short utf16_char, const unsigned
|
|||
return False
|
||||
|
||||
|
||||
cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, const bint suffs_not_prefs):
|
||||
cdef void _set_scs_buffer(
|
||||
const unsigned char[:] searched_string,
|
||||
const unsigned int ss_len,
|
||||
const unsigned char[:] scs,
|
||||
char* buf,
|
||||
const unsigned int buf_len,
|
||||
const bint suffs_not_prefs
|
||||
):
|
||||
""" Picks the UTF-16 characters from *searched_string* that are also in *scs* and writes them in order to *buf*.
|
||||
If *suffs_not_prefs*, the search starts from the end of *searched_string* rather than from the beginning.
|
||||
"""
|
||||
cdef unsigned int buf_len = len(buf), buf_idx = 0
|
||||
cdef unsigned int ss_len = len(searched_string), ss_idx = ss_len - 2 if suffs_not_prefs else 0
|
||||
cdef unsigned int buf_idx = 0, ss_idx = ss_len - 2 if suffs_not_prefs else 0
|
||||
cdef unsigned short working_utf16_char, SPACE = 32
|
||||
|
||||
while buf_idx < buf_len:
|
||||
working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0]
|
||||
if _is_utf16_char_in_scs(working_utf16_char, scs):
|
||||
memcpy(buf, &working_utf16_char, 2)
|
||||
memcpy(buf + buf_idx, &working_utf16_char, 2)
|
||||
buf_idx += 2
|
||||
if suffs_not_prefs:
|
||||
if ss_idx == 0:
|
||||
|
@ -2005,7 +2012,7 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
|
|||
break
|
||||
|
||||
while buf_idx < buf_len:
|
||||
memcpy(buf, &SPACE, 2)
|
||||
memcpy(buf + buf_idx, &SPACE, 2)
|
||||
buf_idx += 2
|
||||
|
||||
def pickle_doc(doc):
|
||||
|
|
Loading…
Reference in New Issue
Block a user