Mirror of https://github.com/explosion/spaCy.git (synced 2025-08-02 19:30:19 +03:00)

Refactoring

This commit is contained in:
parent a18bac40f5
commit 146d286da6
@@ -977,10 +977,13 @@ def test_doc_spans_setdefault(en_tokenizer):
     assert len(doc.spans["key3"]) == 2
 
 
-def test_get_affixes_good_case(en_tokenizer):
+@pytest.mark.parametrize(
+    "case_sensitive", [True, False]
+)
+def test_get_affixes_good_case(en_tokenizer, case_sensitive):
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    prefixes = doc.get_affixes(False, 1, 5, "", 2, 3)
-    suffixes = doc.get_affixes(True, 2, 6, "xx✨rP", 2, 3)
+    prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3)
+    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3)
     assert prefixes[0][3, 3, 3] == suffixes[0][3, 3, 3]
     assert prefixes[0][3, 3, 2] == suffixes[0][3, 3, 4]
     assert (prefixes[0][0, :, 1:] == 0).all()
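The hunk above parametrizes the affix test over case sensitivity. As a minimal, self-contained illustration of the pytest mechanism being used (the test name and body below are hypothetical, not part of this commit):

    import pytest

    @pytest.mark.parametrize("case_sensitive", [True, False])
    def test_parametrize_example(case_sensitive):
        # pytest collects this as two test items: one run with
        # case_sensitive=True and one with case_sensitive=False.
        text = "spaCy"
        processed = text if case_sensitive else text.lower()
        assert processed in ("spaCy", "spacy")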
@@ -991,12 +994,15 @@ def test_get_affixes_good_case(en_tokenizer):
     assert not (suffixes[0][1, :, 2:] == 0).all()
     assert (suffixes[0][1, :, 3:] == 0).all()
     assert suffixes[1][0][1].tolist() == [10024, 0]
-    assert suffixes[1][0][3].tolist() == [114, 112]
+    if case_sensitive:
+        assert suffixes[1][0][3].tolist() == [114, 0]
+    else:
+        assert suffixes[1][0][3].tolist() == [114, 112]
 
 
 def test_get_affixes_4_byte_normal_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
-    suffixes = doc.get_affixes(True, 2, 6, "a", 1, 2)
+    suffixes = doc.get_affixes(True, False, 2, 6, "a", 1, 2)
     assert (suffixes[0][:, 0, 1] == 55296).all()
     assert suffixes[0][3, 0, 4] == 97
     assert suffixes[1][0, 0, 0] == 97
@@ -1005,4 +1011,4 @@ def test_get_affixes_4_byte_normal_char(en_tokenizer):
 def test_get_affixes_4_byte_special_char(en_tokenizer):
     doc = en_tokenizer("and𐌞")
     with pytest.raises(ValueError):
-        doc.get_affixes(True, 2, 6, "𐌞", 2, 3)
+        doc.get_affixes(True, False, 2, 6, "𐌞", 2, 3)
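The literal numbers asserted in these tests are UTF-16 code units. A short plain-Python sketch (not part of the commit) showing where they come from:

    # UTF-16 code units behind the expected values in the tests above.
    assert ord("✨") == 10024          # "✨" fits in one 16-bit code unit
    assert ord("r") == 114 and ord("p") == 112 and ord("a") == 97

    # "𐌞" lies outside the Basic Multilingual Plane, so UTF-16 encodes it as a
    # surrogate pair; the first (high) surrogate is 0xD800, i.e. 55296.
    units = "𐌞".encode("utf-16-be")
    assert int.from_bytes(units[:2], "big") == 55296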
@@ -9,6 +9,7 @@ from ..attrs cimport attr_id_t
 
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes)
 
 
 ctypedef const LexemeC* const_Lexeme_ptr
@@ -94,6 +94,18 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
     return get_token_attr(token, feat_name)
 
 
+cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes):
+    """
+    Returns a memory view of the UTF-16 representation of a string with the default endianness of the platform.
+    Throws a ValueError if *check_2_bytes == True* and one or more characters in the UTF-16 representation
+    occupy four bytes rather than two.
+    """
+    cdef const unsigned char[:] view = memoryview(unicode_string.encode("UTF-16"))[2:]  # first two bytes are endianness
+    if check_2_bytes and len(unicode_string) * 2 != len(view):
+        raise ValueError(Errors.E1044)
+    return view
+
+
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
     missing = "missing"
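For readers less used to Cython memoryviews, here is a rough plain-Python equivalent of what the new helper does; the function name and error message below are illustrative, not the commit's code:

    def get_utf16_codeunits(unicode_string: str, check_2_bytes: bool) -> bytes:
        # str.encode("UTF-16") prepends a 2-byte byte-order mark in the
        # platform's default endianness; slicing it off leaves the raw code units.
        data = unicode_string.encode("UTF-16")[2:]
        # If every character fits into one 16-bit code unit, the encoded length
        # is exactly 2 bytes per character; surrogate pairs (4-byte characters)
        # break that invariant and are rejected when check_2_bytes is True.
        if check_2_bytes and len(unicode_string) * 2 != len(data):
            raise ValueError("string contains characters outside the BMP")
        return data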
@@ -1734,52 +1746,57 @@ cdef class Doc:
                 j += 1
         return output
 
-    def get_affixes(self, bint suffs_not_prefs, int len_start, int len_end, special_chars:str, int sc_len_start, int sc_len_end):
+    def get_affixes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
+                    str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
         """
         TODO
         """
-        cdef bytes byte_string
-        cdef np.uint16_t this_char
-        cdef int idx, len_byte_string, sc_char_idx, sc_test_idx, this_len, this_sc_len
-
-        cdef int num_tokens = len(self)
-        cdef bytes sc_enc = special_chars.lower().encode("utf-16BE")
-        cdef int sc_test_len = len(special_chars)
-        if sc_test_len * 2 != len(sc_enc):
-            raise ValueError(Errors.E1044)
-        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.empty((sc_test_len,), dtype="uint16")
-        for idx in range(sc_test_len):
-            scs[idx] = (sc_enc[idx*2] << 8) + sc_enc[idx * 2 + 1]
-
+        if case_sensitive:
+            token_attrs = [t.orth_ for t in self]
+        else:
+            token_attrs = [t.lower_ for t in self]
+        cdef unsigned int sc_len = len(special_chars)
+        cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True)
+        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")
+
+        cdef unsigned int num_tokens = len(self)
         cdef np.ndarray[np.uint16_t, ndim=3] outputs = numpy.zeros(
             (len_end - len_start, num_tokens, len_end - 1), dtype="uint16")
         cdef np.ndarray[np.uint16_t, ndim=3] sc_outputs = numpy.zeros(
             (sc_len_end - sc_len_start, num_tokens, sc_len_end - 1), dtype="uint16")
 
-        for token_idx in range(num_tokens):
-            byte_string = self[token_idx].lower_.encode("utf-16BE")
-            idx = 0
-            sc_char_idx = 0
-            len_byte_string = len(byte_string)
+        cdef const unsigned char[:] token_bytes
+        cdef np.uint16_t working_char
+        cdef unsigned int token_bytes_len, token_idx, char_idx, working_len, sc_char_idx, sc_test_idx, working_sc_len
+        cdef unsigned int char_byte_idx
+
+        for token_idx in range(num_tokens):
+            token_bytes = get_utf16_memoryview(token_attrs[token_idx], False)
+            char_idx = 0
+            sc_char_idx = 0
+            token_bytes_len = len(token_bytes)
 
-            while (idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and idx * 2 < len_byte_string:
-                char_first_byte_idx = len_byte_string - 2 * (idx + 1) if suffs_not_prefs else idx * 2
-                this_char = (byte_string[char_first_byte_idx] << 8) + byte_string[char_first_byte_idx + 1]
-                for this_len in range(len_end-1, len_start-1, -1):
-                    if idx >= this_len:
+            while (char_idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and char_idx * 2 < token_bytes_len:
+                if suffs_not_prefs:
+                    char_byte_idx = token_bytes_len - 2 * (char_idx + 1)
+                else:
+                    char_byte_idx = char_idx * 2
+                working_char = (<np.uint16_t*> &token_bytes[char_byte_idx])[0]
+                for working_len in range(len_end-1, len_start-1, -1):
+                    if char_idx >= working_len:
                         break
-                    outputs[this_len - len_start, token_idx, idx] = this_char
+                    outputs[working_len - len_start, token_idx, char_idx] = working_char
                 sc_test_idx = 0
-                while sc_test_len > sc_test_idx:
-                    if this_char == scs[sc_test_idx]:
-                        for this_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
-                            if sc_char_idx >= this_sc_len:
+                while sc_len > sc_test_idx:
+                    if working_char == scs[sc_test_idx]:
+                        for working_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
+                            if sc_char_idx >= working_sc_len:
                                 break
-                            sc_outputs[this_sc_len - sc_len_start, token_idx, sc_char_idx] = this_char
+                            sc_outputs[working_sc_len - sc_len_start, token_idx, sc_char_idx] = working_char
                         sc_char_idx += 1
                         break
                     sc_test_idx += 1
-                idx += 1
+                char_idx += 1
         return outputs, sc_outputs
 
     @staticmethod
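A rough plain-Python reading of what the refactored Doc.get_affixes loop computes, heavily simplified (plain strings instead of uint16 arrays, no Cython typing; the function name and structure below are illustrative only, not the commit's code):

    def affix_sketch(tokens, suffs_not_prefs, case_sensitive,
                     len_start, len_end, special_chars, sc_len_start, sc_len_end):
        # Illustrative only: returns per-token strings rather than the
        # numpy.uint16 arrays produced by Doc.get_affixes.
        affixes, sc_affixes = [], []
        for text in tokens:
            if not case_sensitive:
                text = text.lower()
            # Walk characters from the end for suffixes, from the start for prefixes.
            ordered = text[::-1] if suffs_not_prefs else text
            affixes.append([ordered[:length] for length in range(len_start, len_end)])
            specials = [ch for ch in ordered if ch in special_chars]
            sc_affixes.append(["".join(specials[:length])
                               for length in range(sc_len_start, sc_len_end)])
        return affixes, sc_affixes

    # Example: suffixes of lengths 2-5 plus up to 2 special characters per token.
    print(affix_sketch(["spaCy", "and", "Prodigy"], True, False, 2, 6, "xy", 1, 3))

The real method instead packs UTF-16 code units into two numpy.uint16 arrays indexed by affix-length bucket, token index, and character position, which is what the shapes (len_end - len_start, num_tokens, len_end - 1) and (sc_len_end - sc_len_start, num_tokens, sc_len_end - 1) in the diff correspond to.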