Simple changes based on review comments

richardpaulhudson 2022-12-12 11:10:10 +01:00
parent ec1426700e
commit 79b2843a3f
5 changed files with 36 additions and 35 deletions

View File

@@ -69,7 +69,7 @@ def forward(
     features: List[Ints2d] = []
     for doc in docs:
         hashes = doc.get_character_combination_hashes(
-            cs=case_sensitive,
+            case_sensitive=case_sensitive,
             p_lengths=p_lengths,
             s_lengths=s_lengths,
             ps_search_chars=ps_search_chars,

View File

@@ -316,7 +316,6 @@ cdef class StringStore:
         self.keys.push_back(key)
         return value
 
-    @cython.boundscheck(False) # Deactivate bounds checking
     cdef (const unsigned char*, int) utf8_ptr(self, const attr_t hash_val):
         # Returns a pointer to the UTF-8 string together with its length in bytes.
         # This method presumes the calling code has already checked that *hash_val*

View File

@@ -4,7 +4,6 @@ import weakref
 import numpy
 from time import time
 from numpy.testing import assert_array_equal
-from murmurhash.mrmr import hash
 import pytest
 import warnings
 from thinc.api import NumpyOps, get_current_ops
@@ -1017,10 +1016,10 @@ def _get_fnv1a_hash(input: bytes) -> int:
 def test_fnv1a_hash():
     """Checks the conformity of the 64-bit FNV1A implementation with
     http://www.isthe.com/chongo/src/fnv/test_fnv.c.
     The method called here, _get_fnv1a_hash(), is only used in testing;
     in production code, the hashing is performed in a fashion that is interweaved
     with other logic. The conformity of the production code is demonstrated by the
     character combination hash tests, where hashes produced by the production code
     are tested for equality against hashes produced by _get_fnv1a_hash().
     """
     INPUTS = [
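For reference, the 64-bit FNV-1a scheme that this test checks against the published vectors can be sketched in a few lines of plain Python; fnv1a_64() here is a generic illustration of the algorithm, not the production hashing code touched by this diff:

    def fnv1a_64(data: bytes) -> int:
        # 64-bit FNV-1a: start from the offset basis, then for each byte
        # XOR it in and multiply by the FNV prime, modulo 2**64.
        h = 0xCBF29CE484222325
        for byte in data:
            h ^= byte
            h = (h * 0x100000001B3) & 0xFFFFFFFFFFFFFFFF
        return h

    # Two of the published test vectors from test_fnv.c:
    assert fnv1a_64(b"") == 0xCBF29CE484222325
    assert fnv1a_64(b"a") == 0xAF63DC4C8601EC8C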
@@ -1458,7 +1457,7 @@ def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive
         "xx✨rp", case_sensitive
     )
     hashes = doc.get_character_combination_hashes(
-        cs=case_sensitive,
+        case_sensitive=case_sensitive,
         p_lengths=bytes(
             (
                 1,
@@ -1539,7 +1538,7 @@ def test_get_character_combination_hashes_good_case_partial(en_tokenizer):
     doc = en_tokenizer("spaCy✨ and Prodigy")
     ps_search_chars, ps_width_offsets = get_search_char_byte_arrays("rp", False)
     hashes = doc.get_character_combination_hashes(
-        cs=False,
+        case_sensitive=False,
         p_lengths=bytes(),
         s_lengths=bytes(
             (
@@ -1586,7 +1585,7 @@ def test_get_character_combination_hashes_various_lengths(en_tokenizer):
     for p_length in range(1, 8):
         for s_length in range(1, 8):
             hashes = doc.get_character_combination_hashes(
-                cs=False,
+                case_sensitive=False,
                 p_lengths=bytes((p_length,)),
                 s_lengths=bytes((s_length,)),
                 ps_search_chars=bytes(),
@@ -1608,7 +1607,7 @@ def test_get_character_combination_hashes_turkish_i_with_dot(
     doc = en_tokenizer("İ".lower() + "İ")
     search_chars, width_offsets = get_search_char_byte_arrays("İ", case_sensitive)
     hashes = doc.get_character_combination_hashes(
-        cs=case_sensitive,
+        case_sensitive=case_sensitive,
         p_lengths=bytes(
             (
                 1,
@@ -1696,7 +1695,7 @@ def test_get_character_combination_hashes_string_store_spec_cases(
     assert len(doc) == 4
     ps_search_chars, ps_width_offsets = get_search_char_byte_arrays("E", case_sensitive)
     hashes = doc.get_character_combination_hashes(
-        cs=case_sensitive,
+        case_sensitive=case_sensitive,
         p_lengths=bytes((2,)),
         s_lengths=bytes((2,)),
         ps_search_chars=ps_search_chars,
@@ -1726,7 +1725,7 @@ def test_get_character_combination_hashes_string_store_spec_cases(
 def test_character_combination_hashes_empty_lengths(en_tokenizer):
     doc = en_tokenizer("and𐌞")
     assert doc.get_character_combination_hashes(
-        cs=True,
+        case_sensitive=True,
         p_lengths=bytes(),
         s_lengths=bytes(),
         ps_search_chars=bytes(),

View File

@@ -177,7 +177,7 @@ class Doc:
     def get_character_combination_hashes(
         self,
         *,
-        cs: bool,
+        case_sensitive: bool,
         p_lengths: bytes,
         s_lengths: bytes,
         ps_search_chars: bytes,

View File

@@ -41,6 +41,7 @@ from ._serialize import ALL_ATTRS as DOCBIN_ALL_ATTRS
 from ..util import get_words_and_spaces
 
 DEF PADDING = 5
+MAX_UTF8_CHAR_BYTE_WIDTH = 4
 
 cdef int bounds_check(int i, int length, int padding) except -1:
     if (i + padding) < 0:
@@ -1743,10 +1744,9 @@ cdef class Doc:
                 j += 1
         return output
 
-    @cython.boundscheck(False) # Deactivate bounds checking
     def get_character_combination_hashes(self,
         *,
-        const bint cs,
+        const bint case_sensitive,
         const unsigned char* p_lengths,
         const unsigned char* s_lengths,
         const unsigned char* ps_search_chars,
@@ -1789,8 +1789,8 @@ cdef class Doc:
 
         Many of the buffers passed into and used by this method contain single-byte numerical values. This takes advantage of
         the fact that we are hashing short affixes and searching for small groups of characters. The calling code is responsible
-        for ensuring that lengths being passed in cannot exceed 63 and hence that resulting values with maximally four-byte
-        character widths can never exceed 255.
+        for ensuring that lengths being passed in cannot exceed 63 and hence, with maximally four-byte
+        character widths, that individual values within buffers can never exceed the capacity of a single byte (255).
 
         Note that this method performs no data validation itself as it expects the calling code will already have done so, and
         that the behaviour of the code may be erratic if the supplied parameters do not conform to expectations.
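Spelling out the arithmetic behind that invariant: a UTF-8 character occupies at most four bytes, so with affix lengths capped at 63 characters the largest value that can land in one of these single-byte buffers is 63 * 4 = 252, which still fits in an unsigned char. As a worked Python check (MAX_AFFIX_LEN is a hypothetical name for the cap the calling code enforces):

    MAX_AFFIX_LEN = 63            # upper bound enforced by the calling code
    MAX_UTF8_CHAR_BYTE_WIDTH = 4  # widest UTF-8 encoding of a single character
    assert MAX_AFFIX_LEN * MAX_UTF8_CHAR_BYTE_WIDTH == 252
    assert 252 <= 255             # so every length fits in one unsigned byte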
@@ -1809,12 +1809,14 @@ cdef class Doc:
 
         # Define / allocate buffers
         cdef Pool mem = Pool()
-        cdef unsigned char* pref_l_buf = <unsigned char*> mem.alloc(p_max_l, 1)
-        cdef unsigned char* suff_l_buf = <unsigned char*> mem.alloc(s_max_l, 1)
-        cdef unsigned char* ps_res_buf = <unsigned char*> mem.alloc(ps_max_l, 4)
-        cdef unsigned char* ps_l_buf = <unsigned char*> mem.alloc(ps_max_l, 1)
-        cdef unsigned char* ss_res_buf = <unsigned char*> mem.alloc(ss_max_l, 4)
-        cdef unsigned char* ss_l_buf = <unsigned char*> mem.alloc(ss_max_l, 1)
+        cdef unsigned char* pref_l_buf = <unsigned char*> mem.alloc(p_max_l, sizeof(char))
+        cdef unsigned char* suff_l_buf = <unsigned char*> mem.alloc(s_max_l, sizeof(char))
+        cdef unsigned char* ps_res_buf = <unsigned char*> mem.alloc(ps_max_l,
+            MAX_UTF8_CHAR_BYTE_WIDTH * sizeof(char))
+        cdef unsigned char* ps_l_buf = <unsigned char*> mem.alloc(ps_max_l, sizeof(char))
+        cdef unsigned char* ss_res_buf = <unsigned char*> mem.alloc(ss_max_l,
+            MAX_UTF8_CHAR_BYTE_WIDTH * sizeof(char))
+        cdef unsigned char* ss_l_buf = <unsigned char*> mem.alloc(ss_max_l, sizeof(char))
         cdef int doc_l = self.length
         cdef np.ndarray[np.uint64_t, ndim=2] hashes = numpy.empty(
             (doc_l, hashes_per_tok), dtype="uint64")
@@ -1829,7 +1831,7 @@ cdef class Doc:
 
         for tok_i in range(doc_l):
            tok_c = self.c[tok_i]
-            num_tok_attr = tok_c.lex.orth if cs else tok_c.lex.lower
+            num_tok_attr = tok_c.lex.orth if case_sensitive else tok_c.lex.lower
             if num_tok_attr < len(SYMBOLS_BY_INT): # hardly ever happens
                 if num_tok_attr == 0:
                     tok_str_bytes = b""
@@ -2042,21 +2044,22 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
     return lca_matrix
 
 
-@cython.boundscheck(False) # Deactivate bounds checking
 cdef void _set_prefix_lengths(
     const unsigned char* tok_str,
     const int tok_str_l,
     const int p_max_l,
     unsigned char* pref_l_buf,
 ) nogil:
-    """ Populate *pref_l_buf*, which has length *pref_l*, with the byte lengths of the first *pref_l* characters within *tok_str*.
-    Lengths that are greater than the character length of the whole word are populated with the byte length of the whole word.
+    """ Populate *pref_l_buf*, which has length *p_max_l*, with the byte lengths of each of the substrings terminated by the first *p_max_l*
+    characters within *tok_str*. Lengths that are greater than the character length of the whole word are populated with the byte length
+    of the whole word.
 
     tok_str: a UTF-8 representation of a string.
     tok_str_l: the length of *tok_str*.
     p_max_l: the number of characters to process at the beginning of the word.
-    pref_l_buf: a buffer of length *p_max_l* in which to store the lengths. The calling code ensures that lengths
-        greater than 255 cannot occur.
+    pref_l_buf: a buffer of length *p_max_l* in which to store the lengths. The code calling *get_character_combination_hashes()* is
+        responsible for ensuring that *p_max_l* cannot exceed 63 and hence, with maximally four-byte character widths, that individual values
+        within the buffer can never exceed the capacity of a single byte (255).
     """
 
     cdef int tok_str_idx = 1, pref_l_buf_idx = 0
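The behaviour the new docstring describes is easier to see in a pure-Python sketch; prefix_byte_lengths() is a hypothetical helper, not part of the diff, and relies on Python slices clamping at the end of the string to mirror the "byte length of the whole word" rule:

    def prefix_byte_lengths(tok: str, p_max_l: int) -> bytes:
        # Byte length of the UTF-8 encoding of each prefix of
        # 1..p_max_l characters; a slice that runs past the end of
        # the word yields the whole word, so its byte length repeats.
        return bytes(len(tok[:i].encode("utf-8")) for i in range(1, p_max_l + 1))

    # "✨" needs three bytes in UTF-8, so the lengths jump from 2 to 5
    # and then repeat once the prefixes cover the whole word:
    assert prefix_byte_lengths("sp✨", 5) == bytes((1, 2, 5, 5, 5))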
@@ -2075,21 +2078,22 @@ cdef void _set_prefix_lengths(
     memset(pref_l_buf + pref_l_buf_idx, pref_l_buf[pref_l_buf_idx - 1], p_max_l - pref_l_buf_idx)
 
 
-@cython.boundscheck(False) # Deactivate bounds checking
 cdef void _set_suffix_lengths(
     const unsigned char* tok_str,
     const int tok_str_l,
     const int s_max_l,
     unsigned char* suff_l_buf,
 ) nogil:
-    """ Populate *suff_l_buf*, which has length *suff_l*, with the byte lengths of the last *suff_l* characters within *tok_str*.
-    Lengths that are greater than the character length of the whole word are populated with the byte length of the whole word.
+    """ Populate *suff_l_buf*, which has length *s_max_l*, with the byte lengths of each of the substrings started by the last *s_max_l*
+    characters within *tok_str*. Lengths that are greater than the character length of the whole word are populated with the byte length
+    of the whole word.
 
     tok_str: a UTF-8 representation of a string.
     tok_str_l: the length of *tok_str*.
     s_max_l: the number of characters to process at the end of the word.
-    suff_l_buf: a buffer of length *s_max_l* in which to store the lengths. The calling code ensures that lengths
-        greater than 255 cannot occur.
+    suff_l_buf: a buffer of length *s_max_l* in which to store the lengths. The code calling *get_character_combination_hashes()* is
+        responsible for ensuring that *s_max_l* cannot exceed 63 and hence, with maximally four-byte character widths, that individual values
+        within the buffer can never exceed the capacity of a single byte (255).
     """
 
     cdef int tok_str_idx = tok_str_l - 1, suff_l_buf_idx = 0
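The suffix case admits the same kind of sketch; suffix_byte_lengths() is again a hypothetical helper, with negative slices clamping at the start of the word:

    def suffix_byte_lengths(tok: str, s_max_l: int) -> bytes:
        # Byte length of the UTF-8 encoding of the substring started by
        # each of the last 1..s_max_l characters; tok[-i:] with i beyond
        # the character length of the word yields the whole word.
        return bytes(len(tok[-i:].encode("utf-8")) for i in range(1, s_max_l + 1))

    assert suffix_byte_lengths("sp✨", 5) == bytes((3, 4, 5, 5, 5))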
@@ -2105,7 +2109,6 @@ cdef void _set_suffix_lengths(
     memset(suff_l_buf + suff_l_buf_idx, suff_l_buf[suff_l_buf_idx - 1], s_max_l - suff_l_buf_idx)
 
 
-@cython.boundscheck(False) # Deactivate bounds checking
 cdef void _search_for_chars(
     const unsigned char* tok_str,
     const int tok_str_l,