Intermediate state

richardpaulhudson 2022-09-29 13:14:42 +02:00
parent 6f42d79c1e
commit 644d6131af
2 changed files with 68 additions and 12 deletions

spacy/tokens/doc.pxd

@@ -9,7 +9,10 @@ from ..attrs cimport attr_id_t
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes)
+cdef bint check_utf16_char(unsigned short utf16_char, const unsigned char[:] scs)
+cdef void set_scs(const unsigned char[:] searched_string, const unsigned char[:] scs, bytearray working_array, bint suffs_not_prefs)
 ctypedef const LexemeC* const_Lexeme_ptr
 ctypedef const TokenC* const_TokenC_ptr
@@ -69,4 +72,3 @@ cdef class Doc:
     cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1
     cpdef np.ndarray to_array(self, object features)

spacy/tokens/doc.pyx

@@ -16,7 +16,7 @@ import srsly
 from thinc.api import get_array_module, get_current_ops
 from thinc.util import copy_array
 import warnings
-import murmurhash.mrmr
+from murmurhash.mrmr cimport hash32
 from .span cimport Span
 from .token cimport MISSING_DEP
@@ -94,6 +94,50 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
     else:
         return get_token_attr(token, feat_name)
+
+
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes):
+    """
+    Returns a memoryview of the UTF-16 representation of a string, using the default
+    endianness of the platform. Raises a ValueError if *check_2_bytes == True* and one
+    or more characters in the UTF-16 representation occupy four bytes rather than two.
+    """
+    cdef const unsigned char[:] view = unicode_string.encode("UTF-16")[2:]  # strip the two-byte BOM
+    cdef unsigned int us_len = len(unicode_string), view_len = len(view)
+    if check_2_bytes and us_len * 2 != view_len:
+        raise ValueError(Errors.E1046)
+    return view
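
For intuition, the new helper behaves like this pure-Python sketch (the function name and error message here are illustrative only, not part of the commit):

    def utf16_bytes(unicode_string, check_2_bytes):
        # encode("UTF-16") prepends a two-byte byte-order mark; strip it
        encoded = unicode_string.encode("UTF-16")[2:]
        if check_2_bytes and len(unicode_string) * 2 != len(encoded):
            # a character outside the Basic Multilingual Plane occupies a
            # four-byte surrogate pair, breaking the two-bytes-per-character invariant
            raise ValueError("string contains characters needing four UTF-16 bytes")
        return encoded

    utf16_bytes("ab", True)           # b'a\x00b\x00' on a little-endian platform
    utf16_bytes("a\U0001F600", True)  # raises ValueError: the emoji needs a surrogate pair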
+cdef bint check_utf16_char(unsigned short utf16_char, const unsigned char[:] scs):
+    """Returns whether *utf16_char* occurs in *scs*, a buffer of two-byte UTF-16 characters."""
+    cdef unsigned int scs_idx = 0, scs_len = len(scs)
+    while scs_idx < scs_len:
+        if utf16_char == (<unsigned short*> &scs[scs_idx])[0]:
+            return True
+        scs_idx += 2
+    return False
+
+
+cdef void set_scs(const unsigned char[:] searched_string, const unsigned char[:] scs, bytearray working_array, bint suffs_not_prefs):
+    """
+    Copies into *working_array* the UTF-16 characters of *searched_string* that occur
+    in *scs*, scanning from the end if *suffs_not_prefs* and from the start otherwise,
+    then pads the rest of the buffer with spaces.
+    """
+    cdef unsigned int wa_idx = 0, wa_len = len(working_array), ss_len = len(searched_string)
+    cdef unsigned int ss_idx = ss_len - 2 if suffs_not_prefs else 0
+    cdef unsigned short working_utf16_char, SPACE = 32
+    while wa_idx < wa_len:
+        working_utf16_char = (<unsigned short*> &searched_string[ss_idx])[0]
+        if check_utf16_char(working_utf16_char, scs):
+            # store both bytes of the character (little-endian, like the casts above)
+            working_array[wa_idx] = working_utf16_char & 0xFF
+            working_array[wa_idx + 1] = working_utf16_char >> 8
+            wa_idx += 2
+        if suffs_not_prefs:
+            if ss_idx == 0:
+                break
+            ss_idx -= 2
+        else:
+            ss_idx += 2
+            if ss_idx == ss_len:
+                break
+    while wa_idx < wa_len:
+        working_array[wa_idx] = SPACE
+        working_array[wa_idx + 1] = 0
+        wa_idx += 2
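
The effect of set_scs is easier to see in pure Python. This sketch (illustrative names, operating on str rather than raw UTF-16 buffers) builds the string whose encoding the working buffer ends up holding:

    def filtered_specials(searched, special_chars, buf_chars, suffs_not_prefs):
        # walk the token from the end (suffixes) or the start (prefixes),
        # keep only characters that occur in special_chars, truncate to the
        # buffer size and pad the remainder with spaces
        source = searched[::-1] if suffs_not_prefs else searched
        kept = [c for c in source if c in special_chars]
        return "".join(kept[:buf_chars]).ljust(buf_chars)

    filtered_specials("e-mail", "-", 4, False)      # '-   '
    filtered_specials("out-of-date", "-", 4, True)  # '--  '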
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
@@ -1741,23 +1785,33 @@
         TODO
         """
         cdef unsigned int token_index, norm_hash_index, spec_hash_index
-        cdef str token_string, specials_string, working_substring
+        cdef const attr_t* num_token_attr
+        cdef const unsigned char[:] token_string, working_substring, scs = get_utf16_memoryview(special_chars, True)
         cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_tokens = len(self)
+        cdef unsigned int len_token_string
         cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
+        cdef bytearray working_scs_buffer = bytearray(sc_len_end * 2)
+        cdef unsigned int working_start, working_len
         for token_index in range(num_tokens):
-            token_string = self[token_index].orth_ if case_sensitive else self[token_index].lower_
-            if suffs_not_prefs:
-                token_string = token_string[::-1]
+            num_token_attr = &self.c[token_index].lex.orth if case_sensitive else &self.c[token_index].lex.lower
+            token_string = get_utf16_memoryview(self.vocab.strings[num_token_attr[0]], False)
+            len_token_string = len(token_string)
             for norm_hash_index in range(num_norm_hashes):
-                working_substring = token_string[: len_start + norm_hash_index]
-                output[token_index, norm_hash_index] = murmurhash.mrmr.hash(working_substring)
+                working_len = (len_start + norm_hash_index) * 2
+                if working_len > len_token_string:
+                    # short token: hash the whole string, as the str-slicing version did
+                    working_len = len_token_string
+                # suffixes hash the tail of the UTF-16 buffer, prefixes the head
+                working_start = len_token_string - working_len if suffs_not_prefs else 0
+                output[token_index, norm_hash_index] = hash32(<char*> &token_string[working_start], working_len, 0)
-            specials_string = "".join([c for c in token_string if c in special_chars]) + " " * sc_len_end
+            set_scs(token_string, scs, working_scs_buffer, suffs_not_prefs)
             for spec_hash_index in range(num_spec_hashes):
-                working_substring = specials_string[: sc_len_start + spec_hash_index]
-                output[token_index, num_norm_hashes + spec_hash_index] = murmurhash.mrmr.hash(working_substring)
+                working_len = (sc_len_start + spec_hash_index) * 2
+                output[token_index, num_norm_hashes + spec_hash_index] = hash32(<char*> &working_scs_buffer[0], working_len, 0)
         return output.astype("uint64", casting="unsafe", copy=False)
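
Putting the pieces together, each output row holds one hash per affix length followed by one hash per special-character window. A rough pure-Python model of a single row, reusing filtered_specials from the sketch above (the helper name is hypothetical, and Python's built-in hash stands in for the 32-bit MurmurHash used by the Cython code):

    def token_hash_row(token, special_chars, len_start, len_end,
                       sc_len_start, sc_len_end, suffs_not_prefs, hash_fn=hash):
        # affix hashes: the first/last n characters for each n in [len_start, len_end)
        affixes = []
        for n in range(len_start, len_end):
            affixes.append(token[max(len(token) - n, 0):] if suffs_not_prefs
                           else token[:n])
        # special-character hashes: growing prefixes of the filtered, padded buffer
        buf = filtered_specials(token, special_chars, sc_len_end, suffs_not_prefs)
        specials = [buf[:n] for n in range(sc_len_start, sc_len_end)]
        return [hash_fn(s) for s in affixes + specials]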