From 5b29568fb7281a6a30fba0bf0f15d6ba0bb72683 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Thu, 10 Nov 2022 11:37:03 +0100
Subject: [PATCH] Fix wild pointer problem

---
 spacy/strings.pxd    |  2 +-
 spacy/strings.pyx    | 20 +++++++-------------
 spacy/tokens/doc.pyx | 21 +++++++++++++++------
 3 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/spacy/strings.pxd b/spacy/strings.pxd
index 59cdff522..278624202 100644
--- a/spacy/strings.pxd
+++ b/spacy/strings.pxd
@@ -27,4 +27,4 @@ cdef class StringStore:
     cdef const Utf8Str* intern_unicode(self, str py_string)
     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length,
                                      hash_t* precalculated_hash)
-    cdef const unsigned char* utf8_ptr(self, attr_t hash_val)
+    cdef (const unsigned char*, int) utf8_ptr(self, const attr_t hash_val)
diff --git a/spacy/strings.pyx b/spacy/strings.pyx
index 7883bb951..9983f6b1f 100644
--- a/spacy/strings.pyx
+++ b/spacy/strings.pyx
@@ -317,23 +317,17 @@ cdef class StringStore:
         return value

     @cython.boundscheck(False) # Deactivate bounds checking
-    cdef const unsigned char* utf8_ptr(self, const attr_t hash_val):
-        if hash_val == 0:
-            return b""
-        elif hash_val < len(SYMBOLS_BY_INT):
-            return SYMBOLS_BY_INT[hash_val].encode("utf-8")
+    cdef (const unsigned char*, int) utf8_ptr(self, const attr_t hash_val):
+        # Returns a pointer to the UTF-8 string together with its length in bytes.
+        # This method presumes the calling code has already checked that *hash_val*
+        # is not 0 and does not refer to a member of *SYMBOLS_BY_INT*.
         cdef Utf8Str* string = self._map.get(hash_val)
         if string.s[0] < sizeof(string.s) and string.s[0] != 0:
-            return string.s[1:string.s[0]+1]
-        elif string.p[0] < 255:
-            return string.p[1:string.p[0]+1]
-        cdef int i, length
-        i = 0
-        length = 0
+            return &string.s[1], string.s[0]
+        cdef length=0, i=0
         while string.p[i] == 255:
             i += 1
             length += 255
         length += string.p[i]
-        i += 1
-        return string.p[i:length + i]
+        return &string.p[i + 1], length

diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 8333be89b..32259b985 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -21,6 +21,7 @@ from .span cimport Span
 from .token cimport MISSING_DEP
 from ._dict_proxies import SpanGroups
 from .token cimport Token
+from ..symbols import NAMES as SYMBOLS_BY_INT
 from ..lexeme cimport Lexeme, EMPTY_LEXEME
 from ..typedefs cimport attr_t, flags_t
 from ..attrs cimport attr_id_t
@@ -1820,17 +1821,25 @@ cdef class Doc:

         # Define working variables
         cdef TokenC tok_c
-        cdef int tok_i, tok_str_l
+        cdef int tok_i, tok_str_l, working_store_i
         cdef attr_t num_tok_attr
+        cdef bytes tok_str_bytes
         cdef const unsigned char* tok_str
         cdef np.uint64_t* w_hashes_ptr = hashes_ptr

         for tok_i in range(doc_l):
             tok_c = self.c[tok_i]
             num_tok_attr = tok_c.lex.orth if cs else tok_c.lex.lower
-            tok_str = self.vocab.strings.utf8_ptr(num_tok_attr)
-            tok_str_l = strlen(<char*> tok_str)
-
+            if num_tok_attr < len(SYMBOLS_BY_INT): # hardly ever happens
+                if num_tok_attr == 0:
+                    tok_str_bytes = b""
+                else:
+                    tok_str_bytes = SYMBOLS_BY_INT[num_tok_attr].encode("UTF-8")
+                tok_str = tok_str_bytes
+                tok_str_l = len(tok_str_bytes)
+            else:
+                tok_str, tok_str_l = self.vocab.strings.utf8_ptr(num_tok_attr)
+
             if p_max_l > 0:
                 _set_prefix_lengths(tok_str, tok_str_l, p_max_l, pref_l_buf)
                 w_hashes_ptr += _write_hashes(tok_str, p_lengths, pref_l_buf, 0, w_hashes_ptr)
@@ -2055,13 +2064,13 @@ cdef void _set_prefix_lengths(
     cdef int tok_str_idx = 1, pref_l_buf_idx = 0

     while pref_l_buf_idx < p_max_l:
-        if (tok_str[tok_str_idx] == 0 # end of string
+        if (tok_str_idx >= tok_str_l
             or ((tok_str[tok_str_idx] & 0xc0) != 0x80) # not a continuation character
         ):
             pref_l_buf[pref_l_buf_idx] = tok_str_idx
             pref_l_buf_idx += 1
-            if tok_str[tok_str_idx] == 0: # end of string
+            if tok_str_idx >= tok_str_l:
                 break
         tok_str_idx += 1