diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index 4ae27b2b4..bc68a2ab0 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -986,37 +986,41 @@ def _get_unsigned_32_bit_hash(input: str) -> int:
 @pytest.mark.parametrize("case_sensitive", [True, False])
-def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
+def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive):
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
-    suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rp", 1, 3)
+    prefixes = doc.get_character_combination_hashes(case_sensitive=case_sensitive,
+                                                    suffs_not_prefs=False,
+                                                    affix_lengths=[1, 4, 3],
+                                                    search_chars="",
+                                                    search_lengths=[2])
+    suffixes = doc.get_character_combination_hashes(case_sensitive=case_sensitive,
+                                                    suffs_not_prefs=True,
+                                                    affix_lengths=[2, 3, 4, 5],
+                                                    search_chars="xx✨rp",
+                                                    search_lengths=[2, 1])
     assert prefixes[0][0] == _get_unsigned_32_bit_hash("s")
-    assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp")
-    assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
-    assert prefixes[0][3] == _get_unsigned_32_bit_hash(
+    assert prefixes[0][1] == _get_unsigned_32_bit_hash(
         "spaC" if case_sensitive else "spac"
     )
-    assert prefixes[0][4] == _get_unsigned_32_bit_hash("  ")
+    assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
+    assert prefixes[0][3] == _get_unsigned_32_bit_hash("  ")
     assert prefixes[1][0] == _get_unsigned_32_bit_hash("✨")
     assert prefixes[1][1] == _get_unsigned_32_bit_hash("✨")
     assert prefixes[1][2] == _get_unsigned_32_bit_hash("✨")
-    assert prefixes[1][3] == _get_unsigned_32_bit_hash("✨")
-    assert prefixes[1][4] == _get_unsigned_32_bit_hash("  ")
+    assert prefixes[1][3] == _get_unsigned_32_bit_hash("  ")
     assert prefixes[2][0] == _get_unsigned_32_bit_hash("a")
-    assert prefixes[2][1] == _get_unsigned_32_bit_hash("an")
+    assert prefixes[2][1] == _get_unsigned_32_bit_hash("and")
     assert prefixes[2][2] == _get_unsigned_32_bit_hash("and")
-    assert prefixes[2][3] == _get_unsigned_32_bit_hash("and")
-    assert prefixes[2][4] == _get_unsigned_32_bit_hash("  ")
+    assert prefixes[2][3] == _get_unsigned_32_bit_hash("  ")
     assert prefixes[3][0] == _get_unsigned_32_bit_hash("P" if case_sensitive else "p")
-    assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr" if case_sensitive else "pr")
+    assert prefixes[3][1] == _get_unsigned_32_bit_hash(
+        "Prod" if case_sensitive else "prod"
+    )
     assert prefixes[3][2] == _get_unsigned_32_bit_hash(
         "Pro" if case_sensitive else "pro"
     )
-    assert prefixes[3][3] == _get_unsigned_32_bit_hash(
-        "Prod" if case_sensitive else "prod"
-    )
-    assert prefixes[3][4] == _get_unsigned_32_bit_hash("  ")
+    assert prefixes[3][3] == _get_unsigned_32_bit_hash("  ")
     assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy" if case_sensitive else "cy")
     assert suffixes[0][1] == _get_unsigned_32_bit_hash(
@@ -1028,48 +1032,58 @@ def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
     assert suffixes[0][3] == _get_unsigned_32_bit_hash(
         "spaCy" if case_sensitive else "spacy"
     )
-    assert suffixes[0][4] == _get_unsigned_32_bit_hash("p")
-    assert suffixes[0][5] == _get_unsigned_32_bit_hash("p ")
+    assert suffixes[0][4] == _get_unsigned_32_bit_hash("p ")
+    assert suffixes[0][5] == _get_unsigned_32_bit_hash("p")
     assert suffixes[1][0] == _get_unsigned_32_bit_hash("✨")
     assert suffixes[1][1] == _get_unsigned_32_bit_hash("✨")
     assert suffixes[1][2] == _get_unsigned_32_bit_hash("✨")
     assert suffixes[1][3] == _get_unsigned_32_bit_hash("✨")
-    assert suffixes[1][4] == _get_unsigned_32_bit_hash("✨")
-    assert suffixes[1][5] == _get_unsigned_32_bit_hash("✨ ")
+    assert suffixes[1][4] == _get_unsigned_32_bit_hash("✨ ")
+    assert suffixes[1][5] == _get_unsigned_32_bit_hash("✨")
     assert suffixes[2][0] == _get_unsigned_32_bit_hash("nd")
     assert suffixes[2][1] == _get_unsigned_32_bit_hash("and")
     assert suffixes[2][2] == _get_unsigned_32_bit_hash("and")
     assert suffixes[2][3] == _get_unsigned_32_bit_hash("and")
-    assert suffixes[2][4] == _get_unsigned_32_bit_hash(" ")
-    assert suffixes[2][5] == _get_unsigned_32_bit_hash("  ")
+    assert suffixes[2][4] == _get_unsigned_32_bit_hash("  ")
+    assert suffixes[2][5] == _get_unsigned_32_bit_hash(" ")
     assert suffixes[3][0] == _get_unsigned_32_bit_hash("gy")
     assert suffixes[3][1] == _get_unsigned_32_bit_hash("igy")
     assert suffixes[3][2] == _get_unsigned_32_bit_hash("digy")
     assert suffixes[3][3] == _get_unsigned_32_bit_hash("odigy")
-    assert suffixes[3][4] == _get_unsigned_32_bit_hash("r")
+    assert suffixes[3][5] == _get_unsigned_32_bit_hash("r")
     if case_sensitive:
-        assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ")
+        assert suffixes[3][4] == _get_unsigned_32_bit_hash("r ")
     else:
-        assert suffixes[3][5] == _get_unsigned_32_bit_hash("rp")
+        assert suffixes[3][4] == _get_unsigned_32_bit_hash("rp")

     # check values are the same cross-platform
-    assert prefixes[0][3] == 753329845 if case_sensitive else 18446744071614199016
+    assert prefixes[0][1] == (753329845 if case_sensitive else 18446744071614199016)
     assert suffixes[1][0] == 3425774424
-    assert suffixes[2][5] == 3076404432
+    assert suffixes[2][4] == 3076404432


-def test_get_affix_hashes_4_byte_char_at_end(en_tokenizer):
+def test_get_character_combination_hashes_4_byte_char_at_end(en_tokenizer):
     doc = en_tokenizer("and𐌞")
-    suffixes = doc.get_affix_hashes(True, True, 1, 4, "a", 1, 2)
+    suffixes = doc.get_character_combination_hashes(
+        case_sensitive=True,
+        suffs_not_prefs=True,
+        affix_lengths=[1, 2, 3],
+        search_chars="a",
+        search_lengths=[1])
     assert suffixes[0][1] == _get_unsigned_32_bit_hash("𐌞")
     assert suffixes[0][2] == _get_unsigned_32_bit_hash("d𐌞")
     assert suffixes[0][3] == _get_unsigned_32_bit_hash("a")


-def test_get_affix_hashes_4_byte_char_in_middle(en_tokenizer):
+def test_get_character_combination_hashes_4_byte_char_in_middle(en_tokenizer):
     doc = en_tokenizer("and𐌞a")
-    suffixes = doc.get_affix_hashes(True, False, 1, 5, "a", 1, 3)
+    suffixes = doc.get_character_combination_hashes(
+        case_sensitive=False,
+        suffs_not_prefs=True,
+        affix_lengths=[1, 2, 3, 4],
+        search_chars="a",
+        search_lengths=[1, 2])
     assert suffixes[0][0] == _get_unsigned_32_bit_hash("a")
     assert suffixes[0][2] == _get_unsigned_32_bit_hash("𐌞a")
     assert suffixes[0][3] == _get_unsigned_32_bit_hash("d𐌞a")
@@ -1077,7 +1091,19 @@ def test_get_affix_hashes_4_byte_char_in_middle(en_tokenizer):
     assert suffixes[0][5] == _get_unsigned_32_bit_hash("aa")


-def test_get_affixes_4_byte_special_char(en_tokenizer):
+def test_get_character_combination_hashes_4_byte_special_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
     with pytest.raises(ValueError):
-        doc.get_affix_hashes(True, True, 2, 6, "𐌞", 2, 3)
+        doc.get_character_combination_hashes(case_sensitive=True,
+                                             suffs_not_prefs=True,
+                                             affix_lengths=[2, 3, 4, 5],
+                                             search_chars="𐌞",
+                                             search_lengths=[2])
+
+def test_character_combination_hashes_empty_lengths(en_tokenizer):
+    doc = en_tokenizer("and𐌞")
+    assert doc.get_character_combination_hashes(case_sensitive=True,
+                                                suffs_not_prefs=True,
+                                                affix_lengths=[],
+                                                search_chars="",
+                                                search_lengths=[]).shape == (1, 0)
diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index 831e5749f..e8b00051b 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -33,19 +33,24 @@ cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2

 cdef int [:,:] _get_lca_matrix(Doc, int start, int end)

-cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, bint check_2_bytes)
+cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, const bint check_2_bytes)

-cdef bint _is_utf16_char_in_scs(unsigned short utf16_char, const unsigned char[:] scs)
+cdef bint _is_searched_char_in_search_chars_v(
+    const unsigned short searched_char,
+    const unsigned char[:] search_chars_v,
+    const unsigned int search_chars_v_len)

-cdef void _set_scs_buffer(
-    const unsigned char[:] searched_string,
-    const unsigned int ss_len,
-    const unsigned char[:] scs,
-    char* buf,
-    const unsigned int buf_len,
-    const bint suffs_not_prefs
-)
+cdef void _set_found_char_buf(
+    const bint suffs_not_prefs,
+    const unsigned char[:] searched_string_v,
+    const unsigned int searched_string_v_len,
+    const unsigned char[:] search_chars_v,
+    const unsigned int search_chars_v_len,
+    char* found_char_buf,
+    const unsigned int found_char_buf_len,
+)
diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi
index 0eacc0479..f52d0b18f 100644
--- a/spacy/tokens/doc.pyi
+++ b/spacy/tokens/doc.pyi
@@ -174,7 +174,14 @@ class Doc:
         self, doc_json: Dict[str, Any] = ..., validate: bool = False
     ) -> Doc: ...
     def to_utf8_array(self, nr_char: int = ...) -> Ints2d: ...
+    def get_character_combination_hashes(
+        self,
+        *,
+        case_sensitive: bool,
+        suffs_not_prefs: bool,
+        affix_lengths: List[int],
+        search_chars: str,
+        search_lengths: List[int]
+    ) -> Ints2d: ...
     @staticmethod
     def _get_array_attrs() -> Tuple[Any]: ...
-    def get_affix_hashes(self, suffs_not_prefs: bool, case_sensitive: bool, len_start: int, len_end: int,
-                         special_chars: str, sc_len_start: int, sc_len_end: int) -> Ints2d: ...
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index a347c6834..c2bbf7f6a 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1,5 +1,5 @@
 # cython: infer_types=True, bounds_check=False, profile=True
-from typing import Set
+from typing import Set, List

 cimport cython
 cimport numpy as np
@@ -1738,43 +1738,89 @@ cdef class Doc:

         return output

-    def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
-                         str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
+    def get_character_combination_hashes(
+        self,
+        *,
+        bint case_sensitive,
+        bint suffs_not_prefs,
+        affix_lengths: List[int],
+        str search_chars,
+        search_lengths: List[int]
+    ):
         """
-        TODO
+        Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
+        derived from the string (text/orth) of each token.
+
+        case_sensitive: if *False*, the lower-case version of each token string is used as the basis for generating hashes. Note that
+            if *case_sensitive==False*, any upper-case characters in *search_chars* will never be found in token strings.
+        suffs_not_prefs: if *True*, affixes are suffixes, and searching is from the end of each token;
+            if *False*, affixes are prefixes, and searching is from the start of each token.
+        affix_lengths: an integer list specifying the lengths of affixes to be hashed. For example, if *affix_lengths==[2, 3]*,
+            *suffs_not_prefs==True* and *case_sensitive==True*, the suffixes hashed for "spaCy" would be "Cy" and "aCy".
+        search_chars: a string containing characters to search for within each token, starting at the beginning or end depending on the
+            value of *suffs_not_prefs*.
+        search_lengths: an integer list specifying the lengths of search results to be hashed. For example, if *search_lengths==[1, 2]*,
+            *search_chars=="aC"*, *suffs_not_prefs==True* and *case_sensitive==True*, the searched strings hashed for "spaCy" would be
+            "C" and "Ca".
+
+        For a document with tokens ["spaCy", "and", "Prodigy"], the NumPy array returned by
+        *get_character_combination_hashes(case_sensitive=True, suffs_not_prefs=True, affix_lengths=[2, 4, 6],
+        search_chars="yC", search_lengths=[1, 2])* would correspond to
+
+        [[hash("Cy"), hash("paCy"), hash("spaCy"), hash("y"), hash("yC")],
+         [hash("nd"), hash("and"), hash("and"), hash(" "), hash("  ")],
+         [hash("gy"), hash("digy"), hash("rodigy"), hash("y"), hash("y ")]]
+
+        UTF-16 is used to encode the token texts, as this results in two-byte representations for all characters that are realistically
+        likely to occur in normal spaCy documents. UTF-16 can also contain four-byte representations, but neither of the byte pairs in
+        a four-byte representation is ever valid in its own right as a two-byte representation. In the rare case that a four-byte
+        representation occurs in a string being analysed, each of its two-byte pairs is treated as a separate character, while a four-byte
+        representation in *search_chars* is not supported and results in a ValueError(E1046).
         """
-        cdef unsigned int tok_ind, norm_hash_ind, spec_hash_ind, len_tok_str, working_start, working_len
-        cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_toks = len(self)
-        cdef const unsigned char[:] tok_str, scs = _get_utf16_memoryview(special_chars, True)
-        cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64")
-        cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:]  # first two bytes express endianness and are not relevant here
-        cdef char* scs_buffer = scs_buffer_bytes
-        cdef unsigned int buf_len = len(scs_buffer_bytes)
+        cdef const unsigned char[:] search_chars_v = _get_utf16_memoryview(search_chars, True)
+        cdef unsigned int longest_search_length = max(search_lengths) if len(search_lengths) > 0 else 0
+        cdef bytes found_char_buf_bytes = (bytes(" " * longest_search_length, "UTF-16"))[2:]  # first two bytes express endianness
+        cdef char* found_char_buf = found_char_buf_bytes
+        cdef unsigned int search_chars_v_len = len(search_chars_v), found_char_buf_len = len(found_char_buf_bytes)
+
+        cdef unsigned int num_toks = len(self), num_affix_hashes = len(affix_lengths), num_search_hashes = len(search_lengths)
+        cdef np.ndarray[np.int64_t, ndim=2] hashes = numpy.empty((num_toks, num_affix_hashes + num_search_hashes), dtype="int64")
+
+        cdef const unsigned char[:] tok_str_v
+        cdef unsigned int tok_idx, tok_str_v_len, hash_idx, affix_start, hash_len
         cdef attr_t num_tok_attr
         cdef str str_tok_attr
-
-        for tok_ind in range(num_toks):
-            num_tok_attr = self.c[tok_ind].lex.orth if case_sensitive else self.c[tok_ind].lex.lower
+
+        for tok_idx in range(num_toks):
+            num_tok_attr = self.c[tok_idx].lex.orth if case_sensitive else self.c[tok_idx].lex.lower
             str_tok_attr = self.vocab.strings[num_tok_attr]
-            tok_str = _get_utf16_memoryview(str_tok_attr, False)
-            len_tok_str = len(tok_str)
+            tok_str_v = _get_utf16_memoryview(str_tok_attr, False)
+            tok_str_v_len = len(tok_str_v)

-            for norm_hash_ind in range(num_norm_hashes):
-                working_len = (len_start + norm_hash_ind) * 2
-                if working_len > len_tok_str:
-                    working_len = len_tok_str
+            for hash_idx in range(num_affix_hashes):
+                hash_len = affix_lengths[hash_idx] * 2
+                if hash_len > tok_str_v_len:
+                    hash_len = tok_str_v_len
                 if suffs_not_prefs:
-                    working_start = len_tok_str - working_len
+                    affix_start = tok_str_v_len - hash_len
                 else:
-                    working_start = 0
-                output[tok_ind, norm_hash_ind] = hash32(<char*>&tok_str[working_start], working_len, 0)
+                    affix_start = 0
+                hashes[tok_idx, hash_idx] = hash32(<char*>&tok_str_v[affix_start], hash_len, 0)

-            _set_scs_buffer(tok_str, len_tok_str, scs, scs_buffer, buf_len, suffs_not_prefs)
-            for spec_hash_ind in range(num_spec_hashes):
-                working_len = (sc_len_start + spec_hash_ind) * 2
-                output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0)
+            _set_found_char_buf(
+                suffs_not_prefs,
+                tok_str_v,
+                tok_str_v_len,
+                search_chars_v,
+                search_chars_v_len,
+                found_char_buf,
+                found_char_buf_len,
+            )
+
+            for hash_idx in range(num_affix_hashes, num_affix_hashes + num_search_hashes):
+                hash_len = search_lengths[hash_idx - num_affix_hashes] * 2
+                hashes[tok_idx, hash_idx] = hash32(found_char_buf, hash_len, 0)

-        return output
+        return hashes

     @staticmethod
     def _get_array_attrs():
@@ -1974,46 +2020,51 @@ cdef const unsigned char[:] _get_utf16_memoryview(str unicode_string, const bint
     return view

-cdef bint _is_utf16_char_in_scs(const unsigned short utf16_char, const unsigned char[:] scs):
-    cdef unsigned int scs_idx = 0, scs_len = len(scs)
-    while scs_idx < scs_len:
-        if utf16_char == (<unsigned short*>&scs[scs_idx])[0]:
+cdef bint _is_searched_char_in_search_chars_v(
+    const unsigned short searched_char,
+    const unsigned char[:] search_chars_v,
+    const unsigned int search_chars_v_len
+):
+    cdef unsigned int search_chars_v_idx = 0
+    while search_chars_v_idx < search_chars_v_len:
+        if searched_char == (<const unsigned short*>&search_chars_v[search_chars_v_idx])[0]:
             return True
-        scs_idx += 2
+        search_chars_v_idx += 2
     return False


-cdef void _set_scs_buffer(
-    const unsigned char[:] searched_string,
-    const unsigned int ss_len,
-    const unsigned char[:] scs,
-    char* buf,
-    const unsigned int buf_len,
-    const bint suffs_not_prefs
+cdef void _set_found_char_buf(
+    const bint suffs_not_prefs,
+    const unsigned char[:] searched_string_v,
+    const unsigned int searched_string_v_len,
+    const unsigned char[:] search_chars_v,
+    const unsigned int search_chars_v_len,
+    char* found_char_buf,
+    const unsigned int found_char_buf_len,
 ):
-    """ Pick the UFT-16 characters from *searched_string* that are also in *scs* and writes them in order to *buf*.
+    """ Picks the UTF-16 characters from *searched_string_v* that are also in *search_chars_v* and writes them in order to *found_char_buf*.

-        If *suffs_not_prefs*, the search starts from the end of *searched_string* rather than from the beginning.
+        If *suffs_not_prefs*, the search starts from the end of *searched_string_v* rather than from the beginning.
""" - cdef unsigned int buf_idx = 0, ss_idx = ss_len - 2 if suffs_not_prefs else 0 - cdef unsigned short working_utf16_char, SPACE = 32 + cdef unsigned int found_char_buf_idx = 0, searched_string_idx = searched_string_v_len - 2 if suffs_not_prefs else 0 + cdef unsigned short searched_char, SPACE = 32 - while buf_idx < buf_len: - working_utf16_char = ( &searched_string[ss_idx])[0] - if _is_utf16_char_in_scs(working_utf16_char, scs): - memcpy(buf + buf_idx, &working_utf16_char, 2) - buf_idx += 2 + while found_char_buf_idx < found_char_buf_len: + searched_char = ( &searched_string_v[searched_string_idx])[0] + if _is_searched_char_in_search_chars_v(searched_char, search_chars_v, search_chars_v_len): + memcpy(found_char_buf + found_char_buf_idx, &searched_char, 2) + found_char_buf_idx += 2 if suffs_not_prefs: - if ss_idx == 0: + if searched_string_idx == 0: break - ss_idx -= 2 + searched_string_idx -= 2 else: - ss_idx += 2 - if ss_idx == ss_len: + searched_string_idx += 2 + if searched_string_idx == searched_string_v_len: break - while buf_idx < buf_len: - memcpy(buf + buf_idx, &SPACE, 2) - buf_idx += 2 + while found_char_buf_idx < found_char_buf_len: + memcpy(found_char_buf + found_char_buf_idx, &SPACE, 2) + found_char_buf_idx += 2 def pickle_doc(doc): bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])