diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index c5f138a01..f94c22a75 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -2,6 +2,7 @@ import weakref
 
 import numpy
 from numpy.testing import assert_array_equal
+import murmurhash.mrmr
 import pytest
 import warnings
 from thinc.api import NumpyOps, get_current_ops
@@ -976,39 +977,70 @@ def test_doc_spans_setdefault(en_tokenizer):
     doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
     assert len(doc.spans["key3"]) == 2
 
-
 @pytest.mark.parametrize(
     "case_sensitive", [True, False]
 )
-def test_get_affixes_good_case(en_tokenizer, case_sensitive):
+def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
+
+    def _get_unsigned_64_bit_hash(input: str) -> int:
+        if not case_sensitive:
+            input = input.lower()
+        return numpy.asarray([murmurhash.mrmr.hash(input)]).astype("uint64")[0]
+
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3)
-    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rP", 2, 3)
-    assert prefixes[3][3, 3] == suffixes[3][3, 3]
-    assert prefixes[3][3, 2] == suffixes[3][3, 4]
-    assert suffixes[4][1].tolist() == [10024, 0]
+    prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
+    suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
+    assert prefixes[0][0] == _get_unsigned_64_bit_hash("s")
+    assert prefixes[0][1] == _get_unsigned_64_bit_hash("sp")
+    assert prefixes[0][2] == _get_unsigned_64_bit_hash("spa")
+    assert prefixes[0][3] == _get_unsigned_64_bit_hash("spaC")
+    assert prefixes[0][4] == _get_unsigned_64_bit_hash("  ")
+    assert prefixes[1][0] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][1] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][2] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][3] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][4] == _get_unsigned_64_bit_hash("  ")
+    assert prefixes[2][0] == _get_unsigned_64_bit_hash("a")
+    assert prefixes[2][1] == _get_unsigned_64_bit_hash("an")
+    assert prefixes[2][2] == _get_unsigned_64_bit_hash("and")
+    assert prefixes[2][3] == _get_unsigned_64_bit_hash("and")
+    assert prefixes[2][4] == _get_unsigned_64_bit_hash("  ")
+    assert prefixes[3][0] == _get_unsigned_64_bit_hash("P")
+    assert prefixes[3][1] == _get_unsigned_64_bit_hash("Pr")
+    assert prefixes[3][2] == _get_unsigned_64_bit_hash("Pro")
+    assert prefixes[3][3] == _get_unsigned_64_bit_hash("Prod")
+    assert prefixes[3][4] == _get_unsigned_64_bit_hash("  ")
+
+    assert suffixes[0][0] == _get_unsigned_64_bit_hash("yC")
+    assert suffixes[0][1] == _get_unsigned_64_bit_hash("yCa")
+    assert suffixes[0][2] == _get_unsigned_64_bit_hash("yCap")
+    assert suffixes[0][3] == _get_unsigned_64_bit_hash("yCaps")
+    assert suffixes[0][4] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[0][5] == _get_unsigned_64_bit_hash("  ")
+    assert suffixes[1][0] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][1] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][2] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][3] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][4] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][5] == _get_unsigned_64_bit_hash("✨ ")
+    assert suffixes[2][0] == _get_unsigned_64_bit_hash("dn")
+    assert suffixes[2][1] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][2] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][3] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][4] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[2][5] == _get_unsigned_64_bit_hash("  ")
+    assert suffixes[3][0] == _get_unsigned_64_bit_hash("yg")
+    assert suffixes[3][1] == _get_unsigned_64_bit_hash("ygi")
+    assert suffixes[3][2] == _get_unsigned_64_bit_hash("ygid")
+    assert suffixes[3][3] == _get_unsigned_64_bit_hash("ygido")
+    assert suffixes[3][4] == _get_unsigned_64_bit_hash("r")
+
     if case_sensitive:
-        assert suffixes[4][3].tolist() == [114, 80]
+        assert suffixes[3][5] == _get_unsigned_64_bit_hash("rP")
     else:
-        assert suffixes[4][3].tolist() == [114, 112]
-    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3)
-    if case_sensitive:
-        assert suffixes[4][3].tolist() == [114, 0]
-    else:
-        assert suffixes[4][3].tolist() == [114, 112]
-
+        assert suffixes[3][5] == _get_unsigned_64_bit_hash("r ")
-
-def test_get_affixes_4_byte_normal_char(en_tokenizer):
-    doc = en_tokenizer("and𐌞")
-    suffixes = doc.get_affixes(True, False, 2, 6, "a", 1, 2)
-    for i in range(0, 4):
-        assert suffixes[i][0, 1] == 55296
-    assert suffixes[3][0, 4] == 97
-    assert suffixes[4][0, 0] == 97
-
-
-def test_get_affixes_4_byte_special_char(en_tokenizer):
-    doc = en_tokenizer("and𐌞")
-    with pytest.raises(ValueError):
-        doc.get_affixes(True, False, 2, 6, "𐌞", 2, 3)
-
+
+    # check values are the same cross-platform
+    assert prefixes[0][3] == (18446744072456113490 if case_sensitive else 18446744071614199016)
+    assert suffixes[1][0] == 910783208
+    assert suffixes[2][5] == 1696457176
\ No newline at end of file
diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index 56176a3a7..eede5160d 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -9,7 +9,6 @@ from ..attrs cimport attr_id_t
 
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
-cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes)
 
 
 ctypedef const LexemeC* const_Lexeme_ptr
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 3cd8281b8..9bf0b733d 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -16,6 +16,7 @@ import srsly
 from thinc.api import get_array_module, get_current_ops
 from thinc.util import copy_array
 import warnings
+import murmurhash.mrmr
 
 from .span cimport Span
 from .token cimport MISSING_DEP
@@ -94,18 +95,6 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
     return get_token_attr(token, feat_name)
 
 
-cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes):
-    """
-    Returns a memory view of the UTF-16 representation of a string with the default endianness of the platform.
-    Throws a ValueError if *check_2_bytes == True* and one or more characters in the UTF-16 representation
-    occupy four bytes rather than two.
-    """
-    cdef const unsigned char[:] view = memoryview(unicode_string.encode("UTF-16"))[2:]  # first two bytes are endianness
-    if check_2_bytes and len(unicode_string) * 2 != len(view):
-        raise ValueError(Errors.E1044)
-    return view
-
-
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
     missing = "missing"
@@ -1746,63 +1735,38 @@
                 j += 1
         return output
 
-    def get_affixes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
+    def get_affix_hashes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
             str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
         """
-        TODO
+        Returns a 2D uint64 array with one row per token. For each token, the row
+        contains the hashes of its prefixes (or of its reversed suffixes, if
+        *suffs_not_prefs == True*) with lengths from *len_start* up to but
+        excluding *len_end*, followed by the hashes of the strings formed from the
+        token's characters that occur in *special_chars*, with lengths from
+        *sc_len_start* up to but excluding *sc_len_end*; these special-character
+        strings are padded with spaces to the requested lengths. If
+        *case_sensitive == False*, tokens are lowercased before hashing.
         """
-        if case_sensitive:
-            token_attrs = [t.orth_ for t in self]
-        else:
-            token_attrs = [t.lower_ for t in self]
-            special_chars = special_chars.lower()
-        cdef unsigned int sc_len = len(special_chars)
-        cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True)
-        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")
-        cdef np.ndarray[np.uint16_t, ndim=2] output
-        cdef unsigned int num_tokens = len(self), num_normal_affixes = len_end - len_start, working_len
+        cdef unsigned int token_index, norm_hash_index, spec_hash_index
+        cdef str token_string, specials_string, working_substring
+        cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_tokens = len(self)
+        cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
 
-        outputs = []
-        for working_len in range(num_normal_affixes):
-            output = numpy.zeros((num_tokens, len_start + working_len), dtype="uint16")
-            outputs.append(output)
-        for working_len in range(sc_len_end - sc_len_start):
-            output = numpy.zeros((num_tokens, sc_len_start + working_len), dtype="uint16")
-            outputs.append(output)
+        for token_index in range(num_tokens):
+            token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
+            if suffs_not_prefs:
+                token_string = token_string[::-1]
 
-        cdef const unsigned char[:] token_bytes
-        cdef np.uint16_t working_char
-        cdef unsigned int token_bytes_len, token_idx, char_idx, sc_char_idx, sc_test_idx, working_sc_len
-        cdef unsigned int char_byte_idx
+            for norm_hash_index in range(num_norm_hashes):
+                working_substring = token_string[: len_start + norm_hash_index]
+                output[token_index, norm_hash_index] = murmurhash.mrmr.hash(working_substring)
 
-        for token_idx in range(num_tokens):
-            token_bytes = get_utf16_memoryview(token_attrs[token_idx], False)
-            char_idx = 0
-            sc_char_idx = 0
-            token_bytes_len = len(token_bytes)
+            specials_string = "".join([c for c in token_string if c in special_chars]) + " " * sc_len_end
+            for spec_hash_index in range(num_spec_hashes):
+                working_substring = specials_string[: sc_len_start + spec_hash_index]
+                output[token_index, num_norm_hashes + spec_hash_index] = murmurhash.mrmr.hash(working_substring)
 
-            while (char_idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and char_idx * 2 < token_bytes_len:
-                if suffs_not_prefs:
-                    char_byte_idx = token_bytes_len - 2 * (char_idx + 1)
-                else:
-                    char_byte_idx = char_idx * 2
-                working_char = (<np.uint16_t*> &token_bytes[char_byte_idx])[0]
-                for working_len in range(len_end-1, len_start-1, -1):
-                    if char_idx >= working_len:
-                        break
-                    outputs[working_len - len_start][token_idx, char_idx] = working_char
-                sc_test_idx = 0
-                while sc_len > sc_test_idx:
-                    if working_char == scs[sc_test_idx]:
-                        for working_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
-                            if sc_char_idx >= working_sc_len:
-                                break
-                            outputs[num_normal_affixes + working_sc_len - sc_len_start][token_idx, sc_char_idx] = working_char
-                        sc_char_idx += 1
-                        break
-                    sc_test_idx += 1
-                char_idx += 1
-        return outputs
+        return output.astype("uint64", casting="unsafe", copy=False)
 
     @staticmethod
     def _get_array_attrs():
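
Usage sketch (illustrative, not part of the diff): assuming this branch is built and the murmurhash package is installed, the new Doc.get_affix_hashes method could be exercised as below; the argument values mirror the case-sensitive prefix call in the test above, giving four prefix-hash columns (lengths 1 to 4) plus one special-character column (length 2) per token.

    import numpy
    import spacy

    nlp = spacy.blank("en")  # a tokenizer-only pipeline is enough
    doc = nlp("spaCy and Prodigy")

    # suffs_not_prefs=False, case_sensitive=True, prefix lengths [1, 5),
    # no special characters, special-string lengths [2, 3)
    hashes = doc.get_affix_hashes(False, True, 1, 5, "", 2, 3)

    assert hashes.shape == (len(doc), 5)          # one row per token
    assert hashes.dtype == numpy.dtype("uint64")  # 32-bit MurmurHash values cast to uint64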