diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 1519043b0..4ae27b2b4 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -977,26 +977,26 @@ def test_doc_spans_setdefault(en_tokenizer): doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]])) assert len(doc.spans["key3"]) == 2 -@pytest.mark.parametrize( - "case_sensitive", [True, False] -) -def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive): - def _get_unsigned_32_bit_hash(input:str) -> int: - if not case_sensitive: - input = input.lower() - working_hash = hash(input.encode("UTF-16")[2:]) - if working_hash < 0: - working_hash = working_hash + (2<<31) - return working_hash +def _get_unsigned_32_bit_hash(input: str) -> int: + working_hash = hash(input.encode("UTF-16")[2:]) + if working_hash < 0: + working_hash = working_hash + (2 << 31) + return working_hash + + +@pytest.mark.parametrize("case_sensitive", [True, False]) +def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive): doc = en_tokenizer("spaCy✨ and Prodigy") prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3) - suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3) + suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rp", 1, 3) assert prefixes[0][0] == _get_unsigned_32_bit_hash("s") assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp") assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa") - assert prefixes[0][3] == _get_unsigned_32_bit_hash("spaC") + assert prefixes[0][3] == _get_unsigned_32_bit_hash( + "spaC" if case_sensitive else "spac" + ) assert prefixes[0][4] == _get_unsigned_32_bit_hash(" ") assert prefixes[1][0] == _get_unsigned_32_bit_hash("✨") assert prefixes[1][1] == _get_unsigned_32_bit_hash("✨") @@ -1008,18 +1008,28 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive): assert prefixes[2][2] == _get_unsigned_32_bit_hash("and") assert prefixes[2][3] == _get_unsigned_32_bit_hash("and") assert prefixes[2][4] == _get_unsigned_32_bit_hash(" ") - assert prefixes[3][0] == _get_unsigned_32_bit_hash("P") - assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr") - assert prefixes[3][2] == _get_unsigned_32_bit_hash("Pro") - assert prefixes[3][3] == _get_unsigned_32_bit_hash("Prod") + assert prefixes[3][0] == _get_unsigned_32_bit_hash("P" if case_sensitive else "p") + assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr" if case_sensitive else "pr") + assert prefixes[3][2] == _get_unsigned_32_bit_hash( + "Pro" if case_sensitive else "pro" + ) + assert prefixes[3][3] == _get_unsigned_32_bit_hash( + "Prod" if case_sensitive else "prod" + ) assert prefixes[3][4] == _get_unsigned_32_bit_hash(" ") - assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy") - assert suffixes[0][1] == _get_unsigned_32_bit_hash("aCy") - assert suffixes[0][2] == _get_unsigned_32_bit_hash("paCy") - assert suffixes[0][3] == _get_unsigned_32_bit_hash("spaCy") - assert suffixes[0][4] == _get_unsigned_32_bit_hash(" ") - assert suffixes[0][5] == _get_unsigned_32_bit_hash(" ") + assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy" if case_sensitive else "cy") + assert suffixes[0][1] == _get_unsigned_32_bit_hash( + "aCy" if case_sensitive else "acy" + ) + assert suffixes[0][2] == _get_unsigned_32_bit_hash( + "paCy" if case_sensitive else "pacy" + ) + assert suffixes[0][3] == _get_unsigned_32_bit_hash( + "spaCy" if case_sensitive else "spacy" + ) + assert suffixes[0][4] == _get_unsigned_32_bit_hash("p") + assert suffixes[0][5] == _get_unsigned_32_bit_hash("p ") assert suffixes[1][0] == _get_unsigned_32_bit_hash("✨") assert suffixes[1][1] == _get_unsigned_32_bit_hash("✨") assert suffixes[1][2] == _get_unsigned_32_bit_hash("✨") @@ -1039,12 +1049,35 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive): assert suffixes[3][4] == _get_unsigned_32_bit_hash("r") if case_sensitive: - assert suffixes[3][5] == _get_unsigned_32_bit_hash("rP") - else: assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ") + else: + assert suffixes[3][5] == _get_unsigned_32_bit_hash("rp") # check values are the same cross-platform - assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016 + assert prefixes[0][3] == 753329845 if case_sensitive else 18446744071614199016 assert suffixes[1][0] == 3425774424 assert suffixes[2][5] == 3076404432 + +def test_get_affix_hashes_4_byte_char_at_end(en_tokenizer): + doc = en_tokenizer("and𐌞") + suffixes = doc.get_affix_hashes(True, True, 1, 4, "a", 1, 2) + assert suffixes[0][1] == _get_unsigned_32_bit_hash("𐌞") + assert suffixes[0][2] == _get_unsigned_32_bit_hash("d𐌞") + assert suffixes[0][3] == _get_unsigned_32_bit_hash("a") + + +def test_get_affix_hashes_4_byte_char_in_middle(en_tokenizer): + doc = en_tokenizer("and𐌞a") + suffixes = doc.get_affix_hashes(True, False, 1, 5, "a", 1, 3) + assert suffixes[0][0] == _get_unsigned_32_bit_hash("a") + assert suffixes[0][2] == _get_unsigned_32_bit_hash("𐌞a") + assert suffixes[0][3] == _get_unsigned_32_bit_hash("d𐌞a") + assert suffixes[0][4] == _get_unsigned_32_bit_hash("a") + assert suffixes[0][5] == _get_unsigned_32_bit_hash("aa") + + +def test_get_affixes_4_byte_special_char(en_tokenizer): + doc = en_tokenizer("and𐌞") + with pytest.raises(ValueError): + doc.get_affix_hashes(True, True, 2, 6, "𐌞", 2, 3) diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd index 24c938e1c..831e5749f 100644 --- a/spacy/tokens/doc.pxd +++ b/spacy/tokens/doc.pxd @@ -39,7 +39,14 @@ cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, bint check cdef bint _is_utf16_char_in_scs(unsigned short utf16_char, const unsigned char[:] scs) -cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, bint suffs_not_prefs) +cdef void _set_scs_buffer( + const unsigned char[:] searched_string, + const unsigned int ss_len, + const unsigned char[:] scs, + char* buf, + const unsigned int buf_len, + const bint suffs_not_prefs +) cdef class Doc: diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi index a40fa74aa..0eacc0479 100644 --- a/spacy/tokens/doc.pyi +++ b/spacy/tokens/doc.pyi @@ -176,3 +176,5 @@ class Doc: def to_utf8_array(self, nr_char: int = ...) -> Ints2d: ... @staticmethod def _get_array_attrs() -> Tuple[Any]: ... + def get_affix_hashes(self, suffs_not_prefs: bool, case_sensitive: bool, len_start: int, len_end: int, + special_chars: str, sc_len_start: int, sc_len_end: int) -> Ints2d: ... diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 5b091c1eb..a347c6834 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1749,6 +1749,7 @@ cdef class Doc: cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64") cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:] # first two bytes express endianness and are not relevant here cdef char* scs_buffer = scs_buffer_bytes + cdef unsigned int buf_len = len(scs_buffer_bytes) cdef attr_t num_tok_attr cdef str str_tok_attr @@ -1768,7 +1769,7 @@ cdef class Doc: working_start = 0 output[tok_ind, norm_hash_ind] = hash32( &tok_str[working_start], working_len, 0) - _set_scs_buffer(tok_str, scs, scs_buffer, suffs_not_prefs) + _set_scs_buffer(tok_str, len_tok_str, scs, scs_buffer, buf_len, suffs_not_prefs) for spec_hash_ind in range(num_spec_hashes): working_len = (sc_len_start + spec_hash_ind) * 2 output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0) @@ -1982,18 +1983,24 @@ cdef bint _is_utf16_char_in_scs(const unsigned short utf16_char, const unsigned return False -cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned char[:] scs, char* buf, const bint suffs_not_prefs): +cdef void _set_scs_buffer( + const unsigned char[:] searched_string, + const unsigned int ss_len, + const unsigned char[:] scs, + char* buf, + const unsigned int buf_len, + const bint suffs_not_prefs +): """ Pick the UFT-16 characters from *searched_string* that are also in *scs* and writes them in order to *buf*. If *suffs_not_prefs*, the search starts from the end of *searched_string* rather than from the beginning. """ - cdef unsigned int buf_len = len(buf), buf_idx = 0 - cdef unsigned int ss_len = len(searched_string), ss_idx = ss_len - 2 if suffs_not_prefs else 0 + cdef unsigned int buf_idx = 0, ss_idx = ss_len - 2 if suffs_not_prefs else 0 cdef unsigned short working_utf16_char, SPACE = 32 while buf_idx < buf_len: working_utf16_char = ( &searched_string[ss_idx])[0] if _is_utf16_char_in_scs(working_utf16_char, scs): - memcpy(buf, &working_utf16_char, 2) + memcpy(buf + buf_idx, &working_utf16_char, 2) buf_idx += 2 if suffs_not_prefs: if ss_idx == 0: @@ -2005,7 +2012,7 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned break while buf_idx < buf_len: - memcpy(buf, &SPACE, 2) + memcpy(buf + buf_idx, &SPACE, 2) buf_idx += 2 def pickle_doc(doc):