From 5d32dd624620b14332eb1b13460b823c950795ba Mon Sep 17 00:00:00 2001
From: "richard@explosion.ai"
Date: Thu, 3 Nov 2022 20:54:07 +0100
Subject: [PATCH] Intermediate state

---
 spacy/errors.py                 |   2 +-
 spacy/ml/models/tok2vec.py      |   4 +-
 spacy/tests/doc/test_doc_api.py | 169 +++++++-------------------------
 spacy/tests/test_util.py        |  24 +++++
 spacy/tokens/doc.pxd            |  18 ++--
 spacy/tokens/doc.pyx            | 144 +++++++++++----------------
 spacy/util.py                   |  17 ++--
 7 files changed, 135 insertions(+), 243 deletions(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index 896ad2ca6..f10723d20 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -955,7 +955,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
              "knowledge base, use `InMemoryLookupKB`.")
     E1047 = ("Invalid rich group config '{label}'.")
-    E1048 = ("Length > 31 in rich group config '{label}.")
+    E1048 = ("Length > 63 in rich group config '{label}'.")
     E1049 = ("Rich group config {label} specifies lengths that are not in ascending order.")
     E1050 = ("Error splitting UTF-8 byte string into separate characters.")
 
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index a0b0fc3b3..5cdbabf52 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -199,6 +199,8 @@ def _verify_rich_config_group(
     if lengths is not None or rows is not None:
         if is_search_char_group and (search_chars is None or len(search_chars) == 0):
             raise ValueError(Errors.E1047.format(label=label))
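+        # An assumed rationale for the 63-character cap below: 63 characters of at
+        # most 4 UTF-8 bytes each encode to at most 252 bytes, so byte offsets into
+        # the search string always fit the single-byte values used downstream.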
+        if search_chars is not None and len(search_chars) > 63:
+            raise ValueError(Errors.E1048.format(label=label))
         if lengths is None or rows is None:
             raise ValueError(Errors.E1047.format(label=label))
         if len(lengths) != len(rows):
             raise ValueError(Errors.E1047.format(label=label))
@@ -208,7 +210,7 @@ def _verify_rich_config_group(
     elif search_chars is not None:
         raise ValueError(Errors.E1047.format(label=label))
     if lengths is not None:
-        if lengths[-1] > 31:
+        if lengths[-1] > 63:
             raise ValueError(Errors.E1048.format(label=label))
         if len(lengths) != len(set(lengths)) or lengths != sorted(lengths):
             raise ValueError(Errors.E1049.format(label=label))

diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index 632b4ab3b..f76dc45b5 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -3,7 +3,7 @@ from pickle import EMPTY_DICT
 import weakref
 
 import numpy
-import ctypes
+from time import time
 from numpy.testing import assert_array_equal
 from murmurhash.mrmr import hash
 import pytest
@@ -1438,8 +1438,13 @@ def _encode_and_hash(input: str, *, reverse: bool = False) -> int:
 @pytest.mark.parametrize("case_sensitive", [True, False])
 def test_get_character_combination_hashes_good_case(en_tokenizer, case_sensitive):
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("Rp", case_sensitive)
-    ss1, ss2, ss3, ss4 = get_search_char_byte_arrays("xx✨rp", case_sensitive)
+
+    ps_search_chars, ps_width_offsets = get_search_char_byte_arrays(
+        "Rp", case_sensitive
+    )
+    ss_search_chars, ss_width_offsets = get_search_char_byte_arrays(
+        "xx✨rp", case_sensitive
+    )
     hashes = doc.get_character_combination_hashes(
         cs=case_sensitive,
         p_lengths=bytes(
             (
                 1,
                 3,
                 4,
             )
         ),
-        p_max_l=4,
         s_lengths=bytes(
             (
                 2,
                 3,
                 4,
                 5,
             )
         ),
-        s_max_l=5,
-        ps_1byte_ch=ps1,
-        ps_1byte_ch_l=len(ps1),
-        ps_2byte_ch=ps2,
-        ps_2byte_ch_l=len(ps2),
-        ps_3byte_ch=ps3,
-        ps_3byte_ch_l=len(ps3),
-        ps_4byte_ch=ps4,
-        ps_4byte_ch_l=len(ps4),
+        ps_search_chars=ps_search_chars,
+        ps_width_offsets=ps_width_offsets,
         ps_lengths=bytes((2,)),
-        ps_max_l=2,
-        ss_1byte_ch=ss1,
-        ss_1byte_ch_l=len(ss1),
-        ss_2byte_ch=ss2,
-        ss_2byte_ch_l=len(ss2),
-        ss_3byte_ch=ss3,
-        ss_3byte_ch_l=len(ss3),
-        ss_4byte_ch=ss4,
-        ss_4byte_ch_l=len(ss4),
+        ss_search_chars=ss_search_chars,
+        ss_width_offsets=ss_width_offsets,
         ss_lengths=bytes(
             (
                 1,
                 2,
             )
         ),
-        ss_max_l=2,
-        hashes_per_tok=10,
     )
-
     assert hashes[0][0] == _encode_and_hash("s")
     assert hashes[0][1] == _encode_and_hash("spa")
     assert hashes[0][2] == _encode_and_hash("spaC" if case_sensitive else "spac")
     assert hashes[0][3] == _encode_and_hash("yC" if case_sensitive else "yc")
     assert hashes[0][4] == _encode_and_hash("yCa" if case_sensitive else "yca")
     assert hashes[0][5] == _encode_and_hash("yCap" if case_sensitive else "ycap")
     assert hashes[0][6] == _encode_and_hash("yCaps" if case_sensitive else "ycaps")
-
     assert hashes[0][7] == _encode_and_hash("p")
     assert hashes[0][8] == _encode_and_hash("p")
     assert hashes[0][9] == _encode_and_hash("p")
@@ -1539,11 +1525,10 @@
 def test_get_character_combination_hashes_good_case_partial(en_tokenizer):
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("rp", False)
+    ps_search_chars, ps_width_offsets = get_search_char_byte_arrays("rp", False)
     hashes = doc.get_character_combination_hashes(
         cs=False,
         p_lengths=bytes(),
-        p_max_l=0,
         s_lengths=bytes(
             (
                 2,
                 3,
                 4,
                 5,
             )
         ),
-        s_max_l=5,
-        ps_1byte_ch=ps1,
-        ps_1byte_ch_l=len(ps1),
-        ps_2byte_ch=ps2,
-        ps_2byte_ch_l=len(ps2),
-        ps_3byte_ch=ps3,
-        ps_3byte_ch_l=len(ps3),
-        ps_4byte_ch=ps4,
-        ps_4byte_ch_l=len(ps4),
+        ps_search_chars=ps_search_chars,
+        ps_width_offsets=ps_width_offsets,
         ps_lengths=bytes((2,)),
-        ps_max_l=2,
-        ss_1byte_ch=bytes(),
-        ss_1byte_ch_l=0,
-        ss_2byte_ch=bytes(),
-        ss_2byte_ch_l=0,
-        ss_3byte_ch=bytes(),
-        ss_3byte_ch_l=0,
-        ss_4byte_ch=bytes(),
-        ss_4byte_ch_l=0,
+        ss_search_chars=bytes(),
+        ss_width_offsets=bytes(),
         ss_lengths=bytes(),
-        ss_max_l=0,
-        hashes_per_tok=5,
     )
 
     assert hashes[0][0] == _encode_and_hash("yc")
@@ -1607,30 +1576,13 @@ def test_get_character_combination_hashes_copying_in_middle(en_tokenizer):
     hashes = doc.get_character_combination_hashes(
         cs=False,
         p_lengths=bytes((p_length,)),
-        p_max_l=p_length,
         s_lengths=bytes((s_length,)),
-        s_max_l=s_length,
-        ps_1byte_ch=bytes(),
-        ps_1byte_ch_l=0,
-        ps_2byte_ch=bytes(),
-        ps_2byte_ch_l=0,
-        ps_3byte_ch=bytes(),
-        ps_3byte_ch_l=0,
-        ps_4byte_ch=bytes(),
-        ps_4byte_ch_l=0,
+        ps_search_chars=bytes(),
+        ps_width_offsets=bytes(),
         ps_lengths=bytes(),
-        ps_max_l=0,
-        ss_1byte_ch=bytes(),
-        ss_1byte_ch_l=0,
-        ss_2byte_ch=bytes(),
-        ss_2byte_ch_l=0,
-        ss_3byte_ch=bytes(),
-        ss_3byte_ch_l=0,
-        ss_4byte_ch=bytes(),
-        ss_4byte_ch_l=0,
+        ss_search_chars=bytes(),
+        ss_width_offsets=bytes(),
         ss_lengths=bytes(),
-        ss_max_l=0,
-        hashes_per_tok=2,
     )
 
     assert hashes[0][0] == _encode_and_hash("sp𐌞cé"[:p_length])
@@ -1642,7 +1594,7 @@
 def test_get_character_combination_hashes_turkish_i_with_dot(
     en_tokenizer, case_sensitive
 ):
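+    # "İ".lower() is "i" followed by U+0307 COMBINING DOT ABOVE, so each lower-cased
+    # "İ" contributes two characters to the token.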
get_search_char_byte_arrays("İ", case_sensitive) hashes = doc.get_character_combination_hashes( cs=case_sensitive, p_lengths=bytes( @@ -1653,7 +1605,6 @@ def test_get_character_combination_hashes_turkish_i_with_dot( 4, ) ), - p_max_l=4, s_lengths=bytes( ( 1, @@ -1662,15 +1613,8 @@ def test_get_character_combination_hashes_turkish_i_with_dot( 4, ) ), - s_max_l=4, - ps_1byte_ch=s1, - ps_1byte_ch_l=len(s1), - ps_2byte_ch=s2, - ps_2byte_ch_l=len(s2), - ps_3byte_ch=s3, - ps_3byte_ch_l=len(s3), - ps_4byte_ch=s4, - ps_4byte_ch_l=len(s4), + ps_search_chars=search_chars, + ps_width_offsets=width_offsets, ps_lengths=bytes( ( 1, @@ -1679,15 +1623,8 @@ def test_get_character_combination_hashes_turkish_i_with_dot( 4, ) ), - ps_max_l=4, - ss_1byte_ch=s1, - ss_1byte_ch_l=len(s1), - ss_2byte_ch=s2, - ss_2byte_ch_l=len(s2), - ss_3byte_ch=s3, - ss_3byte_ch_l=len(s3), - ss_4byte_ch=s4, - ss_4byte_ch_l=len(s4), + ss_search_chars=search_chars, + ss_width_offsets=width_offsets, ss_lengths=bytes( ( 1, @@ -1696,8 +1633,6 @@ def test_get_character_combination_hashes_turkish_i_with_dot( 4, ) ), - ss_max_l=4, - hashes_per_tok=16, ) COMBINING_DOT_ABOVE = b"\xcc\x87".decode("UTF-8") @@ -1747,34 +1682,17 @@ def test_get_character_combination_hashes_string_store_spec_cases( assert len(long_word) > 255 doc = en_tokenizer(" ".join((symbol, short_word, normal_word, long_word))) assert len(doc) == 4 - ps1, ps2, ps3, ps4 = get_search_char_byte_arrays("E", case_sensitive) + ps_search_chars, ps_width_offsets = get_search_char_byte_arrays("E", case_sensitive) hashes = doc.get_character_combination_hashes( cs=case_sensitive, p_lengths=bytes((2,)), - p_max_l=2, s_lengths=bytes((2,)), - s_max_l=2, - ps_1byte_ch=ps1, - ps_1byte_ch_l=len(ps1), - ps_2byte_ch=ps2, - ps_2byte_ch_l=len(ps2), - ps_3byte_ch=ps3, - ps_3byte_ch_l=len(ps3), - ps_4byte_ch=ps4, - ps_4byte_ch_l=len(ps4), + ps_search_chars=ps_search_chars, + ps_width_offsets=ps_width_offsets, ps_lengths=bytes((2,)), - ps_max_l=2, - ss_1byte_ch=bytes(), - ss_1byte_ch_l=0, - ss_2byte_ch=bytes(), - ss_2byte_ch_l=0, - ss_3byte_ch=bytes(), - ss_3byte_ch_l=0, - ss_4byte_ch=bytes(), - ss_4byte_ch_l=0, + ss_search_chars=bytes(), + ss_width_offsets=bytes(), ss_lengths=bytes(), - ss_max_l=0, - hashes_per_tok=3, ) assert hashes[0][0] == _encode_and_hash("FL" if case_sensitive else "fl") assert hashes[0][1] == _encode_and_hash("91") @@ -1799,30 +1717,13 @@ def test_character_combination_hashes_empty_lengths(en_tokenizer): doc.get_character_combination_hashes( cs=True, p_lengths=bytes(), - p_max_l=0, s_lengths=bytes(), - s_max_l=0, - ps_1byte_ch=bytes(), - ps_1byte_ch_l=0, - ps_2byte_ch=bytes(), - ps_2byte_ch_l=0, - ps_3byte_ch=bytes(), - ps_3byte_ch_l=0, - ps_4byte_ch=bytes(), - ps_4byte_ch_l=0, + ps_search_chars=bytes(), + ps_width_offsets=bytes(), ps_lengths=bytes(), - ps_max_l=0, - ss_1byte_ch=bytes(), - ss_1byte_ch_l=0, - ss_2byte_ch=bytes(), - ss_2byte_ch_l=0, - ss_3byte_ch=bytes(), - ss_3byte_ch_l=0, - ss_4byte_ch=bytes(), - ss_4byte_ch_l=0, + ss_search_chars=bytes(), + ss_width_offsets=bytes(), ss_lengths=bytes(), - ss_max_l=0, - hashes_per_tok=0, ).shape == (1, 0) ) diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index c119fdb79..148bac7cd 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -36,6 +36,30 @@ def test_get_search_char_byte_arrays_all_widths(case_sensitive): assert width_offsets == b"\x00\x02\x04\x07\x0b" +@pytest.mark.parametrize("case_sensitive", [True, False]) +def test_get_search_char_byte_arrays_widths_1_and_3(case_sensitive): + 
+    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
+        "B—", case_sensitive
+    )
+    if case_sensitive:
+        assert search_chars == "B—".encode("utf-8")
+    else:
+        assert search_chars == "b—".encode("utf-8")
+    assert width_offsets == b"\x00\x01\x01\x04\x04"
+
+
+@pytest.mark.parametrize("case_sensitive", [True, False])
+def test_get_search_char_byte_arrays_widths_1_and_4(case_sensitive):
+    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
+        "B𐌞", case_sensitive
+    )
+    if case_sensitive:
+        assert search_chars == "B𐌞".encode("utf-8")
+    else:
+        assert search_chars == "b𐌞".encode("utf-8")
+    assert width_offsets == b"\x00\x01\x01\x01\x05"
+
+
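+@pytest.mark.parametrize("case_sensitive", [True, False])
+def test_get_search_char_byte_arrays_single_width(case_sensitive):
+    # Illustrative sketch: when every search character has the same UTF-8 width
+    # (two bytes here), the unused width ranges collapse into equal consecutive
+    # offsets: the 1-byte range [0, 0) is empty and the 2-byte range [0, 4)
+    # spans the whole byte array.
+    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
+        "éü", case_sensitive
+    )
+    assert search_chars == "éü".encode("utf-8")
+    assert width_offsets == b"\x00\x00\x04\x04\x04"
+
+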
 @pytest.mark.parametrize("case_sensitive", [True, False])
 def test_turkish_i_with_dot(case_sensitive):
     search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(

diff --git a/spacy/tokens/doc.pxd b/spacy/tokens/doc.pxd
index 84b333643..fbb537408 100644
--- a/spacy/tokens/doc.pxd
+++ b/spacy/tokens/doc.pxd
@@ -43,34 +43,28 @@ cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
 
 cdef void _set_prefix_lengths(
     const unsigned char* tok_str,
     const int tok_str_l,
-    unsigned char* pref_l_buf,
     const int p_max_l,
+    unsigned char* pref_l_buf,
 ) nogil
 
 
 cdef void _set_suffix_lengths(
     const unsigned char* tok_str,
     const int tok_str_l,
-    unsigned char* suff_l_buf,
     const int s_max_l,
+    unsigned char* suff_l_buf,
 ) nogil
 
 
 cdef void _search_for_chars(
     const unsigned char* tok_str,
     const int tok_str_l,
-    const unsigned char* s_1byte_ch,
-    const int s_1byte_ch_l,
-    const unsigned char* s_2byte_ch,
-    const int s_2byte_ch_l,
-    const unsigned char* s_3byte_ch,
-    const int s_3byte_ch_l,
-    const unsigned char* s_4byte_ch,
-    const int s_4byte_ch_l,
+    const unsigned char* search_chars,
+    const unsigned char* width_offsets,
+    const int max_res_l,
+    const bint suffs_not_prefs,
     unsigned char* res_buf,
-    int max_res_l,
     unsigned char* l_buf,
-    bint suffs_not_prefs
 ) nogil

diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 7c8764749..f0ab14776 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1739,30 +1739,13 @@
         *,
         const bint cs,
         const unsigned char* p_lengths,
-        const int p_max_l,
         const unsigned char* s_lengths,
-        const int s_max_l,
-        const unsigned char* ps_1byte_ch,
-        const int ps_1byte_ch_l,
-        const unsigned char* ps_2byte_ch,
-        const int ps_2byte_ch_l,
-        const unsigned char* ps_3byte_ch,
-        const int ps_3byte_ch_l,
-        const unsigned char* ps_4byte_ch,
-        const int ps_4byte_ch_l,
+        const unsigned char* ps_search_chars,
+        const unsigned char* ps_width_offsets,
         const unsigned char* ps_lengths,
-        const int ps_max_l,
-        const unsigned char* ss_1byte_ch,
-        const int ss_1byte_ch_l,
-        const unsigned char* ss_2byte_ch,
-        const int ss_2byte_ch_l,
-        const unsigned char* ss_3byte_ch,
-        const int ss_3byte_ch_l,
-        const unsigned char* ss_4byte_ch,
-        const int ss_4byte_ch_l,
+        const unsigned char* ss_search_chars,
+        const unsigned char* ss_width_offsets,
         const unsigned char* ss_lengths,
-        const int ss_max_l,
-        const int hashes_per_tok
     ):
         """
         Returns a 2D NumPy array where the rows represent tokens and the columns represent hashes of various character combinations
         derived from the raw text of each token.
@@ -1778,31 +1761,43 @@ cdef class Doc:
         cs: if *False*, hashes are generated based on the lower-case version of each token.
         p_lengths: an array of single-byte values specifying the lengths of prefixes to be hashed in ascending order.
             For example, if *p_lengths==[2, 3]*, the prefixes hashed for "spaCy" would be "sp" and "spa".
-        p_max_l: the value of *p_lengths[-1]*, or *0* if *p_lengths==None*. Passed in for speed.
         s_lengths: an array of single-byte values specifying the lengths of suffixes to be hashed in ascending order.
-            For example, if *s_lengths==[2, 3]* and *cs == True*, the suffixes hashed for "spaCy" would be "Cy" and "aCy".
-        s_max_l: the value of *s_lengths[-1]*, or *0* if *s_lengths==None*. Passed in for speed.
-        ps_byte_ch: a byte array containing in order n-byte-wide characters to search for within each token,
+            For example, if *s_lengths==[2, 3]* and *cs == True*, the suffixes hashed for "spaCy" would be "yC" and "yCa".
+        ps_search_chars: a byte array containing, in numerical order, UTF-8 characters to search for within each token,
             starting at the beginning.
-        ps_byte_ch_l: the length of *ps_byte_ch*. Passed in for speed.
+        ps_width_offsets: an array of single-byte values [1-char-start, 2-char-start, 3-char-start, 4-char-start, 4-char-end]
+            specifying the offsets within *ps_search_chars* that contain UTF-8 characters with the specified widths.
         ps_lengths: an array of single-byte values specifying the lengths of search results (from the beginning) to be
            hashed in ascending order. For example, if *ps_lengths==[1, 2]*, *ps_search=="aC" and *cs==False*, the searched
            strings hashed for "spaCy" would be "a" and "ac".
-        ps_max_l: the value of *ps_lengths[-1]*, or *0* if *ps_lengths==None*. Passed in for speed.
-        ss_byte_ch: a byte array containing in order n-byte-wide characters to search for within each token,
+        ss_search_chars: a byte array containing, in numerical order, UTF-8 characters to search for within each token,
             starting at the end.
-        ss_byte_ch_l: the length of *ss_byte_ch*. Passed in for speed.
+        ss_width_offsets: an array of single-byte values [1-char-start, 2-char-start, 3-char-start, 4-char-start, 4-char-end]
+            specifying the offsets within *ss_search_chars* that contain UTF-8 characters with the specified widths.
         ss_lengths: an array of single-byte values specifying the lengths of search results (from the end) to be hashed
            in ascending order. For example, if *ss_lengths==[1, 2]*, *ss_search=="aC" and *cs==False*, the searched
            strings hashed for "spaCy" would be "c" and "ca".
-        ss_max_l: the value of *ss_lengths[-1]*, or *0* if *ss_lengths==None*. Passed in for speed.
-        hashes_per_tok: the total number of hashes produced for each token. Passed in for speed.
+
+        Many of the buffers passed into and used by this method contain single-byte numerical values. This takes advantage of
+        the fact that we are hashing short affixes and searching for small groups of characters; the calling code is responsible
+        for ensuring that the lengths being passed in cannot exceed 63 and that *_search_chars buffers are never longer than 255.
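+
+        An illustrative sketch, mirroring the unit tests (*get_search_char_byte_arrays* lives in *spacy.util*):
+
+            ps_search_chars, ps_width_offsets = get_search_char_byte_arrays("rp", False)
+            hashes = doc.get_character_combination_hashes(
+                cs=False,
+                p_lengths=bytes((1, 3)),            # for "spaCy": "s", "spa"
+                s_lengths=bytes((2, 3)),            # for "spaCy": "yc", "yca"
+                ps_search_chars=ps_search_chars,
+                ps_width_offsets=ps_width_offsets,
+                ps_lengths=bytes((2,)),             # hash the first two search matches
+                ss_search_chars=bytes(),
+                ss_width_offsets=bytes(),
+                ss_lengths=bytes(),
+            )
+
+        Each row of the result then contains five hashes for the corresponding token of *doc*.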
""" + # Work out lengths + cdef int p_lengths_l = strlen( p_lengths) + cdef int s_lengths_l = strlen( s_lengths) + cdef int ps_lengths_l = strlen( ps_lengths) + cdef int ss_lengths_l = strlen( ss_lengths) + cdef int hashes_per_tok = p_lengths_l + s_lengths_l + ps_lengths_l + ss_lengths_l + cdef int p_max_l = p_lengths[p_lengths_l - 1] if p_lengths_l > 0 else 0 + cdef int s_max_l = s_lengths[s_lengths_l - 1] if s_lengths_l > 0 else 0 + cdef int ps_max_l = ps_lengths[ps_lengths_l - 1] if ps_lengths_l > 0 else 0 + cdef int ss_max_l = ss_lengths[ss_lengths_l - 1] if ss_lengths_l > 0 else 0 + # Define / allocate buffers cdef Pool mem = Pool() cdef unsigned char* pref_l_buf = mem.alloc(p_max_l, 1) - cdef unsigned char* suff_l_buf = mem.alloc(p_max_l, 1) + cdef unsigned char* suff_l_buf = mem.alloc(s_max_l, 1) cdef unsigned char* ps_res_buf = mem.alloc(ps_max_l, 4) cdef unsigned char* ps_l_buf = mem.alloc(ps_max_l, 1) cdef unsigned char* ss_res_buf = mem.alloc(ss_max_l, 4) @@ -1813,7 +1808,7 @@ cdef class Doc: # Define working variables cdef TokenC tok_c - cdef int hash_idx, tok_i, tok_str_l + cdef int tok_i, tok_str_l cdef attr_t num_tok_attr cdef const unsigned char* tok_str cdef np.uint32_t* w_hashes_ptr = hashes_ptr @@ -1825,21 +1820,21 @@ cdef class Doc: tok_str_l = strlen( tok_str) if p_max_l > 0: - _set_prefix_lengths(tok_str, tok_str_l, pref_l_buf, p_max_l) + _set_prefix_lengths(tok_str, tok_str_l, p_max_l, pref_l_buf) w_hashes_ptr += _write_hashes(tok_str, p_lengths, pref_l_buf, 0, w_hashes_ptr) if s_max_l > 0: - _set_suffix_lengths(tok_str, tok_str_l, suff_l_buf, s_max_l) + _set_suffix_lengths(tok_str, tok_str_l, s_max_l, suff_l_buf) w_hashes_ptr += _write_hashes(tok_str, s_lengths, suff_l_buf, tok_str_l - 1, w_hashes_ptr) if ps_max_l > 0: - _search_for_chars(tok_str, tok_str_l, ps_1byte_ch, ps_1byte_ch_l, ps_2byte_ch, ps_2byte_ch_l, - ps_3byte_ch, ps_3byte_ch_l, ps_4byte_ch, ps_4byte_ch_l, ps_res_buf, ps_max_l, ps_l_buf, False) + _search_for_chars(tok_str, tok_str_l, ps_search_chars, ps_width_offsets, + ps_max_l, False, ps_res_buf, ps_l_buf) w_hashes_ptr += _write_hashes(ps_res_buf, ps_lengths, ps_l_buf, 0, w_hashes_ptr) if ss_max_l > 0: - _search_for_chars(tok_str, tok_str_l, ss_1byte_ch, ss_1byte_ch_l, ss_2byte_ch, ss_2byte_ch_l, - ss_3byte_ch, ss_3byte_ch_l, ss_4byte_ch, ss_4byte_ch_l, ss_res_buf, ss_max_l, ss_l_buf, True) + _search_for_chars(tok_str, tok_str_l, ss_search_chars, ss_width_offsets, + ss_max_l, True, ss_res_buf, ss_l_buf) w_hashes_ptr += _write_hashes(ss_res_buf, ss_lengths, ss_l_buf, 0, w_hashes_ptr) cdef np.ndarray[np.uint32_t, ndim=2] hashes = numpy.empty( @@ -2033,17 +2028,17 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end): cdef void _set_prefix_lengths( const unsigned char* tok_str, const int tok_str_l, - unsigned char* pref_l_buf, const int p_max_l, + unsigned char* pref_l_buf, ) nogil: """ Populate *pref_l_buf*, which has length *pref_l*, with the byte lengths of the first *pref_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are populated with the byte length of the whole word. tok_str: a UTF-8 representation of a string. tok_str_l: the length of *tok_str*. + p_max_l: the number of characters to process at the beginning of the word. pref_l_buf: a buffer of length *p_max_l* in which to store the lengths. The calling code ensures that lengths greater than 255 cannot occur. - p_max_l: the number of characters to process at the beginning of the word. 
""" cdef int tok_str_idx = 1, pref_l_buf_idx = 0 @@ -2066,17 +2061,17 @@ cdef void _set_prefix_lengths( cdef void _set_suffix_lengths( const unsigned char* tok_str, const int tok_str_l, + const int s_max_l, unsigned char* suff_l_buf, - const int s_max_l, ) nogil: """ Populate *suff_l_buf*, which has length *suff_l*, with the byte lengths of the last *suff_l* characters within *tok_str*. Lengths that are greater than the character length of the whole word are populated with the byte length of the whole word. tok_str: a UTF-8 representation of a string. tok_str_l: the length of *tok_str*. + s_max_l: the number of characters to process at the end of the word. suff_l_buf: a buffer of length *s_max_l* in which to store the lengths. The calling code ensures that lengths greater than 255 cannot occur. - s_max_l: the number of characters to process at the end of the word. """ cdef int tok_str_idx = tok_str_l - 1, suff_l_buf_idx = 0 @@ -2096,67 +2091,48 @@ cdef void _set_suffix_lengths( cdef void _search_for_chars( const unsigned char* tok_str, const int tok_str_l, - const unsigned char* s_1byte_ch, - const int s_1byte_ch_l, - const unsigned char* s_2byte_ch, - const int s_2byte_ch_l, - const unsigned char* s_3byte_ch, - const int s_3byte_ch_l, - const unsigned char* s_4byte_ch, - const int s_4byte_ch_l, + const unsigned char* search_chars, + const unsigned char* width_offsets, + const int max_res_l, + const bint suffs_not_prefs, unsigned char* res_buf, - int max_res_l, unsigned char* l_buf, - bint suffs_not_prefs ) nogil: - """ Search *tok_str* within a string for characters within the *s_byte_ch> buffers, starting at the + """ Search *tok_str* within a string for characters within *search_chars*, starting at the beginning or end depending on the value of *suffs_not_prefs*. Wherever a character matches, - it is added to *res_buf* and the byte length up to that point is added to *len_buf*. When nothing - more is found, the remainder of *len_buf* is populated wth the byte length from the last result, + it is added to *res_buf* and the byte length up to that point is added to *l_buf*. When nothing + more is found, the remainder of *l_buf* is populated wth the byte length from the last result, which may be *0* if the search was not successful. tok_str: a UTF-8 representation of a string. tok_str_l: the length of *tok_str*. - s_byte_ch: a byte array containing in order n-byte-wide characters to search for. - res_buf: the buffer in which to place the search results. + search_chars: a byte array containing, in numerical order, UTF-8 characters to search for within *tok_str*. + width_offsets: an array of single-byte values [1-char-start, 2-char-start, 3-char-start, 4-char-start, 4-char-end] + specifying the offsets within *search_chars* that contain UTF-8 characters with the specified widths. max_res_l: the maximum number of found characters to place in *res_buf*. - l_buf: a buffer of length *max_res_l* in which to store the byte lengths. - The calling code ensures that lengths greater than 255 cannot occur. suffs_not_prefs: if *True*, searching starts from the end of the word; if *False*, from the beginning. + res_buf: the buffer in which to place the search results. + l_buf: a buffer of length *max_res_l* in which to store the byte lengths. + The calling code ensures that lengths greater than 255 cannot occur. 
""" cdef int res_buf_idx = 0, l_buf_idx = 0, ch_wdth, tok_start_idx, search_char_idx - cdef int search_chars_l - cdef const unsigned char* search_chars - cdef int last_tok_str_idx = tok_str_l if suffs_not_prefs else 0 cdef int this_tok_str_idx = tok_str_l - 1 if suffs_not_prefs else 1 - while True: - if ( - this_tok_str_idx == tok_str_l or - (tok_str[this_tok_str_idx] & 0xc0) != 0x80 # not continuation character, always applies to [0]. + while this_tok_str_idx >= 0 and this_tok_str_idx <= tok_str_l: + if ( + (this_tok_str_idx == tok_str_l) or + ((tok_str[this_tok_str_idx] & 0xc0) != 0x80) # not continuation character, always applies to [0]. ): if this_tok_str_idx > last_tok_str_idx: ch_wdth = this_tok_str_idx - last_tok_str_idx else: ch_wdth = last_tok_str_idx - this_tok_str_idx - if ch_wdth == 1: - search_chars = s_1byte_ch - search_chars_l = s_1byte_ch_l - elif ch_wdth == 2: - search_chars = s_2byte_ch - search_chars_l = s_2byte_ch_l - elif ch_wdth == 3: - search_chars = s_3byte_ch - search_chars_l = s_3byte_ch_l - else: - search_chars = s_4byte_ch - search_chars_l = s_4byte_ch_l + tok_start_idx = this_tok_str_idx if suffs_not_prefs else last_tok_str_idx - - search_char_idx = 0 - while search_char_idx < search_chars_l: + search_char_idx = width_offsets[ch_wdth - 1] + while search_char_idx < width_offsets[ch_wdth]: cmp_result = memcmp(&tok_str[tok_start_idx], &search_chars[search_char_idx], ch_wdth) if cmp_result == 0: memcpy(res_buf + res_buf_idx, &search_chars[search_char_idx], ch_wdth) @@ -2171,12 +2147,8 @@ cdef void _search_for_chars( last_tok_str_idx = this_tok_str_idx if suffs_not_prefs: this_tok_str_idx -= 1 - if this_tok_str_idx < 0: - break else: this_tok_str_idx += 1 - if this_tok_str_idx > tok_str_l: - break # fill in unused characters in the length buffer memset(l_buf + l_buf_idx, res_buf_idx, max_res_l - l_buf_idx) diff --git a/spacy/util.py b/spacy/util.py index 90160d023..e7dc7dd56 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1757,10 +1757,10 @@ def get_search_char_byte_arrays( search_char_string = search_char_string.lower() ordered_search_char_string = "".join(sorted(set(search_char_string))) search_chars = ordered_search_char_string.encode("UTF-8") - width_offsets = [0, -1, -1, -1, -1] - working_start = -1 - working_width = 1 - for idx in range(len(search_chars) + 1): + width_offsets = [-1] * 5 + working_start = 0 + working_width = 0 + for idx in range(1, len(search_chars) + 1): if ( idx == len(search_chars) or search_chars[idx] & 0xC0 != 0x80 # not continuation byte @@ -1769,11 +1769,10 @@ def get_search_char_byte_arrays( if this_width > 4 or this_width < working_width: raise RuntimeError(Errors.E1050) if this_width > working_width: - width_offsets[this_width - 1] = working_start + for i in range(working_width, 5): + width_offsets[i] = working_start working_width = this_width working_start = idx - width_offsets[this_width] = idx - for i in range(5): - if width_offsets[i] == -1: - width_offsets[i] = width_offsets[i - 1] + for i in range(this_width, 5): + width_offsets[i] = idx return search_chars, bytes((width_offsets))