Intermediate state

This commit is contained in:
richardpaulhudson 2022-09-30 22:26:14 +02:00
parent da63b9448b
commit d296ae9d8e
2 changed files with 70 additions and 73 deletions

View File

@ -2,7 +2,7 @@ import weakref
import numpy
from numpy.testing import assert_array_equal
import murmurhash.mrmr
from murmurhash.mrmr import hash
import pytest
import warnings
from thinc.api import NumpyOps, get_current_ops
@ -982,65 +982,69 @@ def test_doc_spans_setdefault(en_tokenizer):
)
def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
# Purpose: call doc.get_affix_hashes() once for prefixes and once for suffixes
# and compare each returned cell with a reference hash computed by a local
# helper, then pin a few literal values.
# NOTE(review): this listing looks like a rendered diff in which removed and
# added lines are interleaved without +/- markers (two helper `def` lines in a
# row, every assertion present in a 64-bit and a 32-bit flavour). Presumably
# the `_64_bit` lines are the pre-commit version and the `_32_bit` lines the
# post-commit version -- confirm against the real file before relying on it.
def _get_unsigned_64_bit_hash(input:str) -> int:
def _get_unsigned_32_bit_hash(input:str) -> int:
# Reference values are computed on the lower-cased string when the test runs
# in case-insensitive mode.
if not case_sensitive:
input = input.lower()
return numpy.asarray([murmurhash.mrmr.hash(input)]).astype("uint64")[0]
# Python's "UTF-16" codec prepends a 2-byte byte-order mark; [2:] drops it so
# only the character data is hashed.
# NOTE(review): `hash` is presumably murmurhash.mrmr.hash (see the file's
# imports), not the builtin -- confirm.
working_hash = hash(input.encode("UTF-16")[2:])
# 2<<31 == 2**32: adding it maps a negative value onto the unsigned 32-bit
# range (assumes the hash is a signed 32-bit integer -- TODO confirm).
if working_hash < 0:
working_hash = working_hash + (2<<31)
return working_hash
doc = en_tokenizer("spaCy✨ and Prodigy")
# NOTE(review): positional arguments are presumably (suffs_not_prefs,
# case_sensitive, len_start, len_end, special_chars, sc_len_start, sc_len_end)
# -- verify against the get_affix_hashes signature. The assertions below index
# 5 columns per token for `prefixes` and 6 columns per token for `suffixes`.
prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
# Prefix expectations, 64-bit helper flavour (presumably the removed lines).
# NOTE(review): the empty-string and single-space literals in this test may
# have lost characters or repeated spaces during extraction -- confirm the
# literals against the real file.
assert prefixes[0][0] == _get_unsigned_64_bit_hash("s")
assert prefixes[0][1] == _get_unsigned_64_bit_hash("sp")
assert prefixes[0][2] == _get_unsigned_64_bit_hash("spa")
assert prefixes[0][3] == _get_unsigned_64_bit_hash("spaC")
assert prefixes[0][4] == _get_unsigned_64_bit_hash(" ")
assert prefixes[1][0] == _get_unsigned_64_bit_hash("")
assert prefixes[1][1] == _get_unsigned_64_bit_hash("")
assert prefixes[1][2] == _get_unsigned_64_bit_hash("")
assert prefixes[1][3] == _get_unsigned_64_bit_hash("")
assert prefixes[1][4] == _get_unsigned_64_bit_hash(" ")
assert prefixes[2][0] == _get_unsigned_64_bit_hash("a")
assert prefixes[2][1] == _get_unsigned_64_bit_hash("an")
assert prefixes[2][2] == _get_unsigned_64_bit_hash("and")
assert prefixes[2][3] == _get_unsigned_64_bit_hash("and")
assert prefixes[2][4] == _get_unsigned_64_bit_hash(" ")
assert prefixes[3][0] == _get_unsigned_64_bit_hash("P")
assert prefixes[3][1] == _get_unsigned_64_bit_hash("Pr")
assert prefixes[3][2] == _get_unsigned_64_bit_hash("Pro")
assert prefixes[3][3] == _get_unsigned_64_bit_hash("Prod")
assert prefixes[3][4] == _get_unsigned_64_bit_hash(" ")
# Same prefix expectations, 32-bit helper flavour (presumably the added lines).
assert prefixes[0][0] == _get_unsigned_32_bit_hash("s")
assert prefixes[0][1] == _get_unsigned_32_bit_hash("sp")
assert prefixes[0][2] == _get_unsigned_32_bit_hash("spa")
assert prefixes[0][3] == _get_unsigned_32_bit_hash("spaC")
assert prefixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[1][0] == _get_unsigned_32_bit_hash("")
assert prefixes[1][1] == _get_unsigned_32_bit_hash("")
assert prefixes[1][2] == _get_unsigned_32_bit_hash("")
assert prefixes[1][3] == _get_unsigned_32_bit_hash("")
assert prefixes[1][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[2][0] == _get_unsigned_32_bit_hash("a")
assert prefixes[2][1] == _get_unsigned_32_bit_hash("an")
assert prefixes[2][2] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][3] == _get_unsigned_32_bit_hash("and")
assert prefixes[2][4] == _get_unsigned_32_bit_hash(" ")
assert prefixes[3][0] == _get_unsigned_32_bit_hash("P")
assert prefixes[3][1] == _get_unsigned_32_bit_hash("Pr")
assert prefixes[3][2] == _get_unsigned_32_bit_hash("Pro")
assert prefixes[3][3] == _get_unsigned_32_bit_hash("Prod")
assert prefixes[3][4] == _get_unsigned_32_bit_hash(" ")
# Suffix expectations, 64-bit helper flavour: the expected strings are written
# back to front ("yC", "yCa", ...), unlike the 32-bit block further down.
assert suffixes[0][0] == _get_unsigned_64_bit_hash("yC")
assert suffixes[0][1] == _get_unsigned_64_bit_hash("yCa")
assert suffixes[0][2] == _get_unsigned_64_bit_hash("yCap")
assert suffixes[0][3] == _get_unsigned_64_bit_hash("yCaps")
assert suffixes[0][4] == _get_unsigned_64_bit_hash(" ")
assert suffixes[0][5] == _get_unsigned_64_bit_hash(" ")
assert suffixes[1][0] == _get_unsigned_64_bit_hash("")
assert suffixes[1][1] == _get_unsigned_64_bit_hash("")
assert suffixes[1][2] == _get_unsigned_64_bit_hash("")
assert suffixes[1][3] == _get_unsigned_64_bit_hash("")
assert suffixes[1][4] == _get_unsigned_64_bit_hash("")
assert suffixes[1][5] == _get_unsigned_64_bit_hash("")
assert suffixes[2][0] == _get_unsigned_64_bit_hash("dn")
assert suffixes[2][1] == _get_unsigned_64_bit_hash("dna")
assert suffixes[2][2] == _get_unsigned_64_bit_hash("dna")
assert suffixes[2][3] == _get_unsigned_64_bit_hash("dna")
assert suffixes[2][4] == _get_unsigned_64_bit_hash(" ")
assert suffixes[2][5] == _get_unsigned_64_bit_hash(" ")
assert suffixes[3][0] == _get_unsigned_64_bit_hash("yg")
assert suffixes[3][1] == _get_unsigned_64_bit_hash("ygi")
assert suffixes[3][2] == _get_unsigned_64_bit_hash("ygid")
assert suffixes[3][3] == _get_unsigned_64_bit_hash("ygido")
assert suffixes[3][4] == _get_unsigned_64_bit_hash("r")
# Suffix expectations, 32-bit helper flavour: the expected strings are written
# in natural reading order ("Cy", "aCy", ...).
assert suffixes[0][0] == _get_unsigned_32_bit_hash("Cy")
assert suffixes[0][1] == _get_unsigned_32_bit_hash("aCy")
assert suffixes[0][2] == _get_unsigned_32_bit_hash("paCy")
assert suffixes[0][3] == _get_unsigned_32_bit_hash("spaCy")
assert suffixes[0][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[0][5] == _get_unsigned_32_bit_hash(" ")
assert suffixes[1][0] == _get_unsigned_32_bit_hash("")
assert suffixes[1][1] == _get_unsigned_32_bit_hash("")
assert suffixes[1][2] == _get_unsigned_32_bit_hash("")
assert suffixes[1][3] == _get_unsigned_32_bit_hash("")
assert suffixes[1][4] == _get_unsigned_32_bit_hash("")
assert suffixes[1][5] == _get_unsigned_32_bit_hash("")
assert suffixes[2][0] == _get_unsigned_32_bit_hash("nd")
assert suffixes[2][1] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][2] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][3] == _get_unsigned_32_bit_hash("and")
assert suffixes[2][4] == _get_unsigned_32_bit_hash(" ")
assert suffixes[2][5] == _get_unsigned_32_bit_hash(" ")
assert suffixes[3][0] == _get_unsigned_32_bit_hash("gy")
assert suffixes[3][1] == _get_unsigned_32_bit_hash("igy")
assert suffixes[3][2] == _get_unsigned_32_bit_hash("digy")
assert suffixes[3][3] == _get_unsigned_32_bit_hash("odigy")
assert suffixes[3][4] == _get_unsigned_32_bit_hash("r")
# The last suffix column is the only expectation that differs between the
# case-sensitive and case-insensitive runs.
if case_sensitive:
assert suffixes[3][5] == _get_unsigned_64_bit_hash("rP")
assert suffixes[3][5] == _get_unsigned_32_bit_hash("rP")
else:
assert suffixes[3][5] == _get_unsigned_64_bit_hash("r ")
assert suffixes[3][5] == _get_unsigned_32_bit_hash("r ")
# check values are the same cross-platform
# NOTE(review): a conditional expression binds more loosely than `==`, so the
# next line parses as `assert ((prefixes[0][3] == A) if case_sensitive else B)`.
# When case_sensitive is false it asserts the non-zero constant B and can never
# fail; parenthesise the conditional if this line is kept.
assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016
# NOTE(review): suffixes[1][0] and suffixes[2][5] are each compared with two
# different constants below; both cannot hold, so presumably one pair is the
# removed side of the diff -- confirm which pair is current.
assert suffixes[1][0] == 910783208
assert suffixes[2][5] == 1696457176
assert suffixes[1][0] == 3425774424
assert suffixes[2][5] == 3076404432

View File

@ -1745,41 +1745,35 @@ cdef class Doc:
"""
cdef unsigned int tok_ind, norm_hash_ind, spec_hash_ind, len_tok_str, working_start, working_len
cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_toks = len(self)
cdef const unsigned char[:] token_string, scs = _get_utf16_memoryview(special_chars, True)
cdef const unsigned char[:] tok_str, scs = _get_utf16_memoryview(special_chars, True)
cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_toks, num_norm_hashes + num_spec_hashes), dtype="int64")
cdef bytes scs_buffer_bytes = bytes(b"\x20" * sc_len_end) # spaces
cdef bytes scs_buffer_bytes = (bytes(" " * sc_len_end, "UTF-16"))[2:] # first two bytes express endianness and are not relevant here
cdef char* scs_buffer = scs_buffer_bytes
cdef attr_t tok_num_attr
cdef str str_token_attr
cdef np.int64_t last_hash
cdef attr_t num_tok_attr
cdef str str_tok_attr
for tok_ind in range(num_toks):
tok_num_attr = self.c[tok_ind].lex.orth if case_sensitive else self.c[tok_ind].lex.lower
str_token_attr = self.vocab.strings[tok_num_attr]
token_string = _get_utf16_memoryview(str_token_attr, False)
len_tok_str = len(token_string)
num_tok_attr = self.c[tok_ind].lex.orth if case_sensitive else self.c[tok_ind].lex.lower
str_tok_attr = self.vocab.strings[num_tok_attr]
tok_str = _get_utf16_memoryview(str_tok_attr, False)
len_tok_str = len(tok_str)
for norm_hash_ind in range(num_norm_hashes):
working_len = (len_start + norm_hash_ind) * 2
if working_len > len_tok_str:
output[tok_ind, norm_hash_ind] = last_hash
break
working_len = len_tok_str
if suffs_not_prefs:
working_start = 0
else:
working_start = len_tok_str - working_len
if working_start < 0:
output[tok_ind, norm_hash_ind] = last_hash
break
last_hash = hash32(<void*> &token_string[working_start], working_len, 0)
output[tok_ind, norm_hash_ind] = last_hash
else:
working_start = 0
output[tok_ind, norm_hash_ind] = hash32(<void*> &tok_str[working_start], working_len, 0)
_set_scs_buffer(token_string, scs, scs_buffer, suffs_not_prefs)
_set_scs_buffer(tok_str, scs, scs_buffer, suffs_not_prefs)
for spec_hash_ind in range(num_spec_hashes):
working_len = (sc_len_start + spec_hash_ind) * 2
output[tok_ind, num_norm_hashes + spec_hash_ind] = hash32(scs_buffer, working_len, 0)
return output.astype("uint64", casting="unsafe", copy=False)
return output
@staticmethod
def _get_array_attrs():
@ -1999,7 +1993,7 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
while buf_idx < buf_len:
working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0]
if _is_utf16_char_in_scs(working_utf16_char, scs):
buf[buf_idx] = working_utf16_char
memcpy(buf, &working_utf16_char, 2)
buf_idx += 2
if suffs_not_prefs:
if ss_idx == 0:
@ -2011,10 +2005,9 @@ cdef void _set_scs_buffer(const unsigned char[:] searched_string, const unsigned
break
while buf_idx < buf_len:
buf[buf_idx] = SPACE
memcpy(buf, &SPACE, 2)
buf_idx += 2
def pickle_doc(doc):
bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,