Intermediate state

This commit is contained in:
richardpaulhudson 2022-10-27 20:59:30 +02:00
parent c140bd6083
commit 13e417e8d1
5 changed files with 130 additions and 77 deletions

View File

@ -956,7 +956,8 @@ class Errors(metaclass=ErrorsWithCodes):
"knowledge base, use `InMemoryLookupKB`.")
E1047 = ("Invalid rich group config '{label}'.")
E1048 = ("Length > 31 in rich group config '{label}.")
E1049 = ("Error splitting UTF-8 byte string into separate characters.")
E1049 = ("Rich group config {label} specifies lengths that are not in ascending order.")
E1050 = ("Error splitting UTF-8 byte string into separate characters.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -195,7 +195,6 @@ def _verify_rich_config_group(
rows: Optional[List[int]],
search_chars: Optional[str],
is_search_char_group: bool,
case_sensitive: bool,
) -> None:
if lengths is not None or rows is not None:
if is_search_char_group and (search_chars is None or len(search_chars) == 0):
@ -208,8 +207,11 @@ def _verify_rich_config_group(
raise ValueError(Errors.E1047.format(label=label))
elif search_chars is not None:
raise ValueError(Errors.E1047.format(label=label))
if lengths is not None and max(lengths) > 31:
raise ValueError(Errors.E1048.format(label=label))
if lengths is not None:
if lengths[-1] > 31:
raise ValueError(Errors.E1048.format(label=label))
if len(lengths) != len(set(lengths)) or lengths != sorted(lengths):
raise ValueError(Errors.E1049.format(label=label))
@registry.architectures("spacy.RichMultiHashEmbed.v1")
@ -258,6 +260,8 @@ def RichMultiHashEmbed(
plural noun does not become `a` if it is the third or fourth vowel from the
end of the word.
All lengths must be specified in ascending order.
width (int): The output width. Also used as the width of the embedding tables.
Recommended values are between 64 and 300.
attrs (list of attr IDs): The token attributes to embed. A separate
@ -293,19 +297,14 @@ def RichMultiHashEmbed(
if len(rows) != len(attrs):
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
_verify_rich_config_group(
"prefix", pref_lengths, pref_rows, None, False, case_sensitive
)
_verify_rich_config_group(
"suffix", suff_lengths, suff_rows, None, False, case_sensitive
)
_verify_rich_config_group("prefix", pref_lengths, pref_rows, None, False)
_verify_rich_config_group("suffix", suff_lengths, suff_rows, None, False)
_verify_rich_config_group(
"prefix search",
pref_search_lengths,
pref_search_rows,
pref_search_chars,
True,
case_sensitive,
)
_verify_rich_config_group(
"suffix search",
@ -313,7 +312,6 @@ def RichMultiHashEmbed(
suff_search_rows,
suff_search_chars,
True,
case_sensitive,
)
if "PREFIX" in attrs or "SUFFIX" in attrs:

View File

@ -39,20 +39,21 @@ cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
cdef void _set_affix_lengths(
const unsigned char[:] text_buf,
char* aff_len_buf,
const unsigned char[:] tok_str,
unsigned char* aff_len_buf,
const int pref_len,
const int suff_len,
) nogil
cdef bint _search_for_chars(
cdef void _search_for_chars(
const unsigned char[:] tok_str,
const unsigned char[:] s_1byte_ch,
const unsigned char[:] s_2byte_ch,
const unsigned char[:] s_3byte_ch,
const unsigned char[:] s_4byte_ch,
unsigned char* res_buf,
int max_res_l,
unsigned char* len_buf,
bint suffs_not_prefs
) nogil

View File

@ -1751,7 +1751,7 @@ cdef class Doc:
const unsigned char[:] ss_3byte_ch,
const unsigned char[:] ss_4byte_ch,
np.ndarray ss_lengths,
):
) nogil:
"""
Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
derived from the raw text of each token.
@ -1764,19 +1764,19 @@ cdef class Doc:
ss_ variables relate to searches starting at the end of the word
cs: if *False*, hashes are generated based on the lower-case version of each token.
p_lengths: an Ints1d specifying the lengths of prefixes to be hashed. For example, if *p_lengths==[2, 3]*,
the prefixes hashed for "spaCy" would be "sp" and "spa".
s_lengths: an Ints1d specifying the lengths of suffixes to be hashed. For example, if *s_lengths==[2, 3]* and
*cs == True*, the suffixes hashed for "spaCy" would be "Cy" and "aCy".
p_lengths: an Ints1d specifying the lengths of prefixes to be hashed in ascending order. For example,
if *p_lengths==[2, 3]*, the prefixes hashed for "spaCy" would be "sp" and "spa".
s_lengths: an Ints1d specifying the lengths of suffixes to be hashed in ascending order. For example,
if *s_lengths==[2, 3]* and *cs == True*, the suffixes hashed for "spaCy" would be "Cy" and "aCy".
ps_<n>byte_ch: a byte array containing, in order, the n-byte-wide UTF-8 characters to search for within each token,
starting at the beginning.
ps_lengths: an Ints1d specifying the lengths of search results (from the beginning) to be hashed. For example, if
*ps_lengths==[1, 2]*, *ps_search=="aC" and *cs==False*, the searched strings hashed for
ps_lengths: an Ints1d specifying the lengths of search results (from the beginning) to be hashed in ascending order.
For example, if *ps_lengths==[1, 2]*, *ps_search=="aC"* and *cs==False*, the searched strings hashed for
"spaCy" would be "a" and "ac".
ss_<n>byte_ch: a byte array containing, in order, the n-byte-wide UTF-8 characters to search for within each token,
starting at the end.
ss_lengths: an integer list specifying the lengths of search results (from the end) to be hashed. For example, if
*ss_lengths==[1, 2]*, *ss_search=="aC" and *cs==False*, the searched strings hashed for
ss_lengths: an Ints1d specifying the lengths of search results (from the end) to be hashed in ascending order.
For example, if *ss_lengths==[1, 2]*, *ss_search=="aC"* and *cs==False*, the searched strings hashed for
"spaCy" would be "c" and "ca".
"""
@ -1826,30 +1826,33 @@ cdef class Doc:
for hash_idx in range(p_h_num):
offset = aff_len_buf[p_lengths_v[hash_idx]]
if offset > 0:
hash_val = hash32(<void*> &qcktest2[0], offset, 0)
hash_val = hash32(<void*> &tok_str[0], offset, 0)
hashes[tok_i, hash_idx] = hash_val
for hash_idx in range(p_h_num, s_h_end):
offset = s_lengths_v[hash_idx - p_h_num]
if offset > 0:
hash_val = hash32(<void*> &qcktest2[len(qcktest2) - offset], offset, 0)
hash_val = hash32(<void*> &tok_str[len(tok_str) - offset], offset, 0)
hashes[tok_i, hash_idx] = hash_val
if (
ps_h_num > 0 and
_search_for_chars(tok_str, ps_1byte_ch, ps_2byte_ch, ps_3byte_ch, ps_4byte_ch, ps_res_buf, ps_res_len, False)
):
for hash_idx in range(s_h_end, ps_h_end):
aff_len = ps_lengths_v[hash_idx - s_h_end]
hashes[tok_i, hash_idx] = hash32(ps_r_buf, aff_len * sizeof(Py_UCS4), 0)
if (
ss_h_num > 0 and
_search_for_chars(tok_str, ss_1byte_ch, ss_2byte_ch, ss_3byte_ch, ss_4byte_ch, ss_res_buf, ss_res_len, True)
):
if ps_h_num > 0:
_search_for_chars(tok_str, ps_1byte_ch, ps_2byte_ch, ps_3byte_ch, ps_4byte_ch, ps_res_buf, ps_max_l, ps_res_len, False)
hash_val = 0
for hash_idx in range(s_h_end, ps_h_end):
offset = ps_lengths_v[hash_idx - s_h_end]
if offset > 0:
hash_val = hash32(ps_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val
if ss_h_num > 0:
_search_for_chars(tok_str, ss_1byte_ch, ss_2byte_ch, ss_3byte_ch, ss_4byte_ch, ss_res_buf, ss_max_l, ss_res_len, True)
hash_val = 0
for hash_idx in range(ps_h_end, ss_h_end):
aff_len = ss_lengths_v[hash_idx - ps_h_end]
hashes[tok_i, hash_idx] = hash32(ss_r_buf, aff_len * sizeof(Py_UCS4), 0)
offset = ss_lengths_v[hash_idx - ps_h_end]
if offset > 0:
hash_val = hash32(ss_res_buf, offset, 0)
hashes[tok_i, hash_idx] = hash_val
self.mem.free(aff_len_buf)
self.mem.free(ps_res_buf)
@ -2040,76 +2043,126 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
cdef void _set_affix_lengths(
const unsigned char[:] text_buf,
char* aff_len_buf,
const unsigned char[:] tok_str,
unsigned char* aff_len_buf,
const int pref_len,
const int suff_len,
) nogil:
""" TODO : Populate a buffer of length pref+suff with the first pref and the last suff characters of a word within a string.
If the word is shorter than pref and/or suff, the empty character positions in the middle are filled with zeros.
""" TODO : Populate *len_buf*, which has length *pref_len+suff_len* with the byte lengths of the first *pref_len* and the last
*suff_len* characters within *tok_str*. If the word is shorter than pref and/or suff, the empty lengths in the middle are
filled with zeros.
text_buf: a pointer to a UTF-32LE representation of the containing string.
tok_idx: the index of the first character of the word within the containing string.
tok_len: the length of the word.
aff_buf: the buffer to populate.
pref_len: the length of the prefix.
suff_len: the length of the suffix.
to_lower: if *True*, any upper case characters in either affix are converted to lower case.
tok_str: a memoryview of a UTF-8 representation of a string.
aff_len_buf: a buffer of length *pref_len+suff_len* in which to store the lengths. The calling code ensures that lengths
greater than 255 cannot occur.
pref_len: the number of characters to process at the beginning of the word.
suff_len: the number of characters to process at the end of the word.
"""
cdef int text_buf_idx = 0, aff_len_buf_idx = 0, text_buf_len = len(text_buf)
cdef int tok_str_idx = 0, aff_len_buf_idx = 0, tok_str_len = len(tok_str)
while aff_len_buf_idx < pref_len:
if (text_buf[text_buf_idx] >> 6) ^ 2 != 0: # not a continuation character
aff_len_buf[aff_len_buf_idx] = text_buf_idx + 1
if (tok_str[tok_str_idx] & 0xc0) != 0x80: # not a continuation character
aff_len_buf[aff_len_buf_idx] = tok_str_idx + 1
aff_len_buf_idx += 1
text_buf_idx += 1
if text_buf_idx == len(text_buf):
tok_str_idx += 1
if tok_str_idx == len(tok_str):
break
if aff_len_buf_idx < pref_len:
memset(aff_len_buf + aff_len_buf_idx, 0, pref_len - aff_len_buf_idx)
aff_len_buf_idx = pref_len
text_buf_idx = 1
tok_str_idx = 1
while aff_len_buf_idx < pref_len + suff_len:
if (text_buf[text_buf_len - text_buf_idx] >> 6) ^ 2 != 0: # not a continuation character
aff_len_buf[aff_len_buf_idx] = text_buf_len - text_buf_idx
if (tok_str[tok_str_idx] & 0xc0) != 0x80: # not a continuation character
aff_len_buf[aff_len_buf_idx] = tok_str_len - tok_str_idx
aff_len_buf_idx += 1
text_buf_idx += 1
if text_buf_idx > text_buf_len:
tok_str_idx += 1
if tok_str_idx > tok_str_len:
break
if aff_len_buf_idx < pref_len + suff_len:
memset(aff_len_buf + aff_len_buf_idx, 0, pref_len + suff_len - aff_len_buf_idx)
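In pure Python, the byte lengths this function records can be sketched roughly as follows (illustration only; the helper name and example values below are invented, and the Cython version writes into a preallocated C buffer):

# Sketch: for a UTF-8 byte string, record the byte length of each of the
# first pref_len character prefixes and each of the last suff_len character
# suffixes; slots that cannot be filled stay 0. A continuation byte is any
# byte of the form 0b10xxxxxx, i.e. (b & 0xC0) == 0x80.
def affix_byte_lengths(tok_str: bytes, pref_len: int, suff_len: int) -> list:
    out = [0] * (pref_len + suff_len)
    idx = 0
    for i in range(pref_len):                      # prefixes, from the start
        if idx >= len(tok_str):
            break
        idx += 1
        while idx < len(tok_str) and (tok_str[idx] & 0xC0) == 0x80:
            idx += 1
        out[i] = idx
    idx = len(tok_str)
    for i in range(suff_len):                      # suffixes, from the end
        if idx <= 0:
            break
        idx -= 1
        while idx > 0 and (tok_str[idx] & 0xC0) == 0x80:
            idx -= 1
        out[pref_len + i] = len(tok_str) - idx
    return out

affix_byte_lengths("spaCy".encode("utf8"), 3, 2)  # [1, 2, 3, 1, 2]
affix_byte_lengths("méta".encode("utf8"), 3, 2)   # [1, 3, 4, 1, 2]  ("é" is two bytes)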
cdef bint _search_for_chars(
cdef void _search_for_chars(
const unsigned char[:] tok_str,
const unsigned char[:] s_1byte_ch,
const unsigned char[:] s_2byte_ch,
const unsigned char[:] s_3byte_ch,
const unsigned char[:] s_4byte_ch,
unsigned char* res_buf,
int max_res_l,
unsigned char* len_buf,
bint suffs_not_prefs
) nogil:
""" Search a word within a string for characters within *search_buf*, starting at the beginning or
end depending on the value of *suffs_not_prefs*. Wherever a character from *search_buf* matches,
the corresponding character from *lookup_buf* is added to *result_buf*.
""" Search *tok_str* within a string for characters within the *s_<n>byte_ch> buffers, starting at the
beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches,
it is added to *res_buf* and the byte length up to that point is added to *len_buf*.
text_buf: a pointer to a UTF-32LE representation of the containing string.
tok_idx: the index of the first character of the word within the containing string.
tok_len: the length of the word.
search_buf: the characters to search for (ordered).
lookup_buf: characters corresponding to *search_buf* to add to *result_buf* in the case of a match.
Having separate search and lookup arrays enables case-insensitivity to be handled efficiently.
search_buf_len: the length of *search_buf* and hence also of *lookup_buf*.
result_buf: the buffer in which to place the results.
result_buf_len: the length of *result_buf*.
tok_str: a memoryview of a UTF-8 representation of a string.
s_<n>byte_ch: a byte array containing, in order, the n-byte-wide UTF-8 characters to search for.
res_buf: the buffer in which to place the search results.
max_res_l: the maximum number of found characters to place in *res_buf*.
len_buf: a buffer of length *max_res_l* in which to store the byte lengths.
The calling code ensures that lengths greater than 255 cannot occur.
suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning.
Returns *True* if at least one character from *search_buf* was found in the word, otherwise *False*.
"""
cdef int tok_str_len = len(tok_str), search_char_idx = 0, res_buf_idx = 0, len_buf_idx = 0
cdef int last_tok_str_idx = tok_str_len if suffs_not_prefs else 0
cdef int this_tok_str_idx = tok_str_len - 1 if suffs_not_prefs else 1
cdef int ch_wdth, tok_start_idx, cmp_result
cdef const unsigned char[:] search_chars
while True:
if (
this_tok_str_idx == tok_str_len or
(tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character
):
ch_wdth = abs(this_tok_str_idx - last_tok_str_idx)
if ch_wdth == 1:
search_chars = s_1byte_ch
elif ch_wdth == 2:
search_chars = s_2byte_ch
elif ch_wdth == 3:
search_chars = s_3byte_ch
else:
search_chars = s_4byte_ch
tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx
for search_char_idx in range(0, len(search_chars), ch_wdth):
cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth)
if cmp_result == 0:
memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth)
res_buf_idx += ch_wdth
len_buf[len_buf_idx] = res_buf_idx
len_buf_idx += 1
if len_buf_idx == max_res_l:
return
if cmp_result >= 0:
break
last_tok_str_idx = this_tok_str_idx
if suffs_not_prefs:
this_tok_str_idx -= 1
if this_tok_str_idx < 0:
break
else:
this_tok_str_idx += 1
if this_tok_str_idx >= tok_str_len:
break
# zero out any unused bytes at the end of the result buffer
memset(res_buf + res_buf_idx, 0, max_res_l - res_buf_idx)
cdef int result_buf_idx = 0, text_string_idx = tok_idx + (tok_len - 1) if suffs_not_prefs else tok_idx
cdef int search_buf_idx
cdef int cmp_result
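
The overall effect of the search can be sketched in plain Python as follows (illustration only; the Cython version groups the search characters by UTF-8 byte width, compares raw bytes, and records cumulative byte lengths in *len_buf*):

# Sketch: walk the token character by character from the start or the end,
# collect every character that is in the search set, and stop after
# max_res_l matches have been found.
def search_for_chars(token: str, search_chars: str, max_res_l: int,
                     suffs_not_prefs: bool = False) -> str:
    source = reversed(token) if suffs_not_prefs else token
    result = []
    for ch in source:
        if ch in search_chars:
            result.append(ch)
            if len(result) == max_res_l:
                break
    return "".join(result)

search_for_chars("spacy", "ac", 4)                        # "ac"
search_for_chars("spacy", "ac", 4, suffs_not_prefs=True)  # "ca"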

View File

@ -1773,6 +1773,6 @@ def get_search_char_byte_arrays(
elif char_length == 4:
sc4.extend(encoded_search_char_bytes[working_start:idx])
else:
raise RuntimeError(Errors.E1049)
raise RuntimeError(Errors.E1050)
working_start = idx
return bytes(sc1), bytes(sc2), bytes(sc3), bytes(sc4)
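
For reference, the grouping of search characters by UTF-8 byte width that feeds the four returned arrays can be sketched in plain Python (illustration only; the helper name below is invented and the real function scans the already-encoded byte string):

# Sketch: split search characters into four byte strings, one per UTF-8
# encoding width (1-4 bytes), keeping the characters in their original order
# within each group. Any other width would indicate invalid UTF-8.
def split_by_utf8_width(search_chars: str):
    groups = {1: bytearray(), 2: bytearray(), 3: bytearray(), 4: bytearray()}
    for ch in search_chars:
        encoded = ch.encode("utf8")
        groups[len(encoded)].extend(encoded)
    return tuple(bytes(groups[width]) for width in (1, 2, 3, 4))

split_by_utf8_width("aéπ😀")
# (b'a', b'\xc3\xa9', b'\xcf\x80', b'\xf0\x9f\x98\x80')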