Intermediate state

This commit is contained in:
richardpaulhudson 2022-09-30 22:26:14 +02:00
parent da63b9448b
commit d296ae9d8e
2 changed files with 70 additions and 73 deletions

View File

@ -2,7 +2,7 @@ import weakref
import numpy import numpy
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
import murmurhash.mrmr from murmurhash.mrmr import hash
import pytest import pytest
import warnings import warnings
from thinc.api import NumpyOps, get_current_ops from thinc.api import NumpyOps, get_current_ops
@ -982,65 +982,69 @@ def test_doc_spans_setdefault(en_tokenizer):
) )
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive): def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
def _get_unsigned_64_bit_hash(input:str) -> int: def _get_unsigned_32_bit_hash(input:str) -> int:
if not case_sensitive: if not case_sensitive:
input = input.lower() input = input.lower()
return numpy.asarray([murmurhash.mrmr.hash(input)]).astype("uint64")[0] working_hash = hash(input.encode("UTF-16")[2:])
if working_hash < 0:
working_hash = working_hash + (2<<31)
return working_hash
doc = en_tokenizer("spaCy✨ and Prodigy") doc = en_tokenizer("spaCy✨ and Prodigy")
prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3) prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3) suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
assert prefixes[0][0] == _get_unsigned_64_bit_hash("s") assert prefixes[0][0] == _get_unsigned_32_bit_hash("s")
assert prefixes[0][1] == _get_unsigned_64_bit_hash("sp") assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp")
assert prefixes[0][2] == _get_unsigned_64_bit_hash("spa") assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
assert prefixes[0][3] == _get_unsigned_64_bit_hash("spaC") assert prefixes[0][3] == _get_unsigned_32_bit_hash("spaC")
assert prefixes[0][4] == _get_unsigned_64_bit_hash(" ") assert prefixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[1][0] == _get_unsigned_64_bit_hash("") assert prefixes[1][0] == _get_unsigned_32_bit_hash("")
assert prefixes[1][1] == _get_unsigned_64_bit_hash("") assert prefixes[1][1] == _get_unsigned_32_bit_hash("")
assert prefixes[1][2] == _get_unsigned_64_bit_hash("") assert prefixes[1][2] == _get_unsigned_32_bit_hash("")
assert prefixes[1][3] == _get_unsigned_64_bit_hash("") assert prefixes[1][3] == _get_unsigned_32_bit_hash("")
assert prefixes[1][4] == _get_unsigned_64_bit_hash(" ") assert prefixes[1][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[2][0] == _get_unsigned_64_bit_hash("a") assert prefixes[2][0] == _get_unsigned_32_bit_hash("a")
assert prefixes[2][1] == _get_unsigned_64_bit_hash("an") assert prefixes[2][1] == _get_unsigned_32_bit_hash("an")
assert prefixes[2][2] == _get_unsigned_64_bit_hash("and") assert prefixes[2][2] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][3] == _get_unsigned_64_bit_hash("and") assert prefixes[2][3] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][4] == _get_unsigned_64_bit_hash(" ") assert prefixes[2][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[3][0] == _get_unsigned_64_bit_hash("P") assert prefixes[3][0] == _get_unsigned_32_bit_hash("P")
assert prefixes[3][1] == _get_unsigned_64_bit_hash("Pr") assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr")
assert prefixes[3][2] == _get_unsigned_64_bit_hash("Pro") assert prefixes[3][2] == _get_unsigned_32_bit_hash("Pro")
assert prefixes[3][3] == _get_unsigned_64_bit_hash("Prod") assert prefixes[3][3] == _get_unsigned_32_bit_hash("Prod")
assert prefixes[3][4] == _get_unsigned_64_bit_hash(" ") assert prefixes[3][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][0] == _get_unsigned_64_bit_hash("yC") assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy")
assert suffixes[0][1] == _get_unsigned_64_bit_hash("yCa") assert suffixes[0][1] == _get_unsigned_32_bit_hash("aCy")
assert suffixes[0][2] == _get_unsigned_64_bit_hash("yCap") assert suffixes[0][2] == _get_unsigned_32_bit_hash("paCy")
assert suffixes[0][3] == _get_unsigned_64_bit_hash("yCaps") assert suffixes[0][3] == _get_unsigned_32_bit_hash("spaCy")
assert suffixes[0][4] == _get_unsigned_64_bit_hash(" ") assert suffixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][5] == _get_unsigned_64_bit_hash(" ") assert suffixes[0][5] == _get_unsigned_32_bit_hash(" ")
assert suffixes[1][0] == _get_unsigned_64_bit_hash("") assert suffixes[1][0] == _get_unsigned_32_bit_hash("")
assert suffixes[1][1] == _get_unsigned_64_bit_hash("") assert suffixes[1][1] == _get_unsigned_32_bit_hash("")
assert suffixes[1][2] == _get_unsigned_64_bit_hash("") assert suffixes[1][2] == _get_unsigned_32_bit_hash("")
assert suffixes[1][3] == _get_unsigned_64_bit_hash("") assert suffixes[1][3] == _get_unsigned_32_bit_hash("")
assert suffixes[1][4] == _get_unsigned_64_bit_hash("") assert suffixes[1][4] == _get_unsigned_32_bit_hash("")
assert suffixes[1][5] == _get_unsigned_64_bit_hash("") assert suffixes[1][5] == _get_unsigned_32_bit_hash("")
assert suffixes[2][0] == _get_unsigned_64_bit_hash("dn") assert suffixes[2][0] == _get_unsigned_32_bit_hash("nd")
assert suffixes[2][1] == _get_unsigned_64_bit_hash("dna") assert suffixes[2][1] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][2] == _get_unsigned_64_bit_hash("dna") assert suffixes[2][2] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][3] == _get_unsigned_64_bit_hash("dna") assert suffixes[2][3] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][4] == _get_unsigned_64_bit_hash(" ") assert suffixes[2][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[2][5] == _get_unsigned_64_bit_hash(" ") assert suffixes[2][5] == _get_unsigned_32_bit_hash(" ")
assert suffixes[3][0] == _get_unsigned_64_bit_hash("yg") assert suffixes[3][0] == _get_unsigned_32_bit_hash("gy")
assert suffixes[3][1] == _get_unsigned_64_bit_hash("ygi") assert suffixes[3][1] == _get_unsigned_32_bit_hash("igy")
assert suffixes[3][2] == _get_unsigned_64_bit_hash("ygid") assert suffixes[3][2] == _get_unsigned_32_bit_hash("digy")
assert suffixes[3][3] == _get_unsigned_64_bit_hash("ygido") assert suffixes[3][3] == _get_unsigned_32_bit_hash("odigy")
assert suffixes[3][4] == _get_unsigned_64_bit_hash("r") assert suffixes[3][4] == _get_unsigned_32_bit_hash("r")
if case_sensitive: if case_sensitive:
assert suffixes[3][5] == _get_unsigned_64_bit_hash("rP") assert suffixes[3][5] == _get_unsigned_32_bit_hash("rP")
else: else:
assert suffixes[3][5] == _get_unsigned_64_bit_hash("r ") assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ")
# check values are the same cross-platform # check values are the same cross-platform
assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016 assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016
assert suffixes[1][0] == 910783208 assert suffixes[1][0] == 3425774424
assert suffixes[2][5] == 1696457176 assert suffixes[2][5] == 3076404432

View File

@ -1745,41 +1745,35 @@ cdef class Doc:
""" """
cdef unsigned int tok_ind, norm_hash_ind, spec_hash_ind, len_tok_str, working_start, working_len cdef unsigned int tok_ind, norm_hash_ind, spec_hash_ind, len_tok_str, working_start, working_len
cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_toks = len(self) cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_toks = len(self)
cdef const unsigned char[:] token_string, scs = _get_utf16_memoryview(special_chars, True) cdef const unsigned char[:] tok_str, scs = _get_utf16_memoryview(special_chars, True)
cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64") cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64")
cdef bytes scs_buffer_bytes = bytes(b"\x20" * sc_len_end) # spaces cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:] # first two bytes express endianness and are not relevant here
cdef char* scs_buffer = scs_buffer_bytes cdef char* scs_buffer = scs_buffer_bytes
cdef attr_t tok_num_attr cdef attr_t num_tok_attr
cdef str str_token_attr cdef str str_tok_attr
cdef np.int64_t last_hash
for tok_ind in range(num_toks): for tok_ind in range(num_toks):
tok_num_attr = self.c[tok_ind].lex.orth if case_sensitive else self.c[tok_ind].lex.lower num_tok_attr = self.c[tok_ind].lex.orth if case_sensitive else self.c[tok_ind].lex.lower
str_token_attr = self.vocab.strings[tok_num_attr] str_tok_attr = self.vocab.strings[num_tok_attr]
token_string = _get_utf16_memoryview(str_token_attr, False) tok_str = _get_utf16_memoryview(str_tok_attr, False)
len_tok_str = len(token_string) len_tok_str = len(tok_str)
for norm_hash_ind in range(num_norm_hashes): for norm_hash_ind in range(num_norm_hashes):
working_len = (len_start + norm_hash_ind) * 2 working_len = (len_start + norm_hash_ind) * 2
if working_len > len_tok_str: if working_len > len_tok_str:
output[tok_ind, norm_hash_ind] = last_hash working_len = len_tok_str
break
if suffs_not_prefs: if suffs_not_prefs:
working_start = 0
else:
working_start = len_tok_str - working_len working_start = len_tok_str - working_len
if working_start < 0: else:
output[tok_ind, norm_hash_ind] = last_hash working_start = 0
break output[tok_ind, norm_hash_ind] = hash32(<void*> &tok_str[working_start], working_len, 0)
last_hash = hash32(<void*> &token_string[working_start], working_len, 0)
output[tok_ind, norm_hash_ind] = last_hash
_set_scs_buffer(token_string, scs, scs_buffer, suffs_not_prefs) _set_scs_buffer(tok_str, scs, scs_buffer, suffs_not_prefs)
for spec_hash_ind in range(num_spec_hashes): for spec_hash_ind in range(num_spec_hashes):
working_len = (sc_len_start + spec_hash_ind) * 2 working_len = (sc_len_start + spec_hash_ind) * 2
output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0) output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0)
return output.astype("uint64", casting="unsafe", copy=False) return output
@staticmethod @staticmethod
def _get_array_attrs(): def _get_array_attrs():
@ -1999,7 +1993,7 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
while buf_idx < buf_len: while buf_idx < buf_len:
working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0] working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0]
if _is_utf16_char_in_scs(working_utf16_char, scs): if _is_utf16_char_in_scs(working_utf16_char, scs):
buf[buf_idx] = working_utf16_char memcpy(buf, &working_utf16_char, 2)
buf_idx += 2 buf_idx += 2
if suffs_not_prefs: if suffs_not_prefs:
if ss_idx == 0: if ss_idx == 0:
@ -2011,10 +2005,9 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
break break
while buf_idx < buf_len: while buf_idx < buf_len:
buf[buf_idx] = SPACE memcpy(buf, &SPACE, 2)
buf_idx += 2 buf_idx += 2
def pickle_doc(doc): def pickle_doc(doc):
bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"]) bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks, hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,