mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-02 19:30:19 +03:00
Add tests for get_affixes
This commit is contained in:
parent
b2074a15d1
commit
910a6bc98f
|
@ -975,3 +975,40 @@ def test_doc_spans_setdefault(en_tokenizer):
|
|||
assert len(doc.spans["key2"]) == 1
|
||||
doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
|
||||
assert len(doc.spans["key3"]) == 2
|
||||
|
||||
|
||||
def test_get_affixes_good_case(en_tokenizer):
    """Exercise ``Doc.get_affixes`` on a short text that contains an emoji.

    Element 0 of each result is compared between the two calls and checked
    for zero padding; element 1 is checked against the special characters
    that were passed in.
    """
    doc = en_tokenizer("spaCy✨ and Prodigy")
    # NOTE(review): judging by the variable names, the leading boolean
    # presumably selects suffixes (True) vs. prefixes (False) -- confirm
    # against the get_affixes signature.
    prefixes = doc.get_affixes(False, 1, 5, "", 2, 3)
    suffixes = doc.get_affixes(True, 2, 6, "xx✨rP", 2, 3)
    prefix_bytes = prefixes[0]
    suffix_bytes = suffixes[0]

    # At index [3, 3] the two results must agree on these column pairs:
    # columns 6 and 7 line up directly, while prefix columns 4 and 5 are
    # matched against suffix columns 8 and 9.
    for prefix_col, suffix_col in ((6, 6), (7, 7), (4, 8), (5, 9)):
        assert prefix_bytes[3, 3, prefix_col] == suffix_bytes[3, 3, suffix_col]

    # Zero-padding expectations for the leading entries of each result.
    assert (prefix_bytes[0, :, 2:] == 0).all()
    assert not (suffix_bytes[0, :, 2:] == 0).all()
    assert (suffix_bytes[0, :, 4:] == 0).all()
    assert (prefix_bytes[1, :, 4:] == 0).all()
    assert (prefix_bytes[:, 1, 2:] == 0).all()
    assert not (suffix_bytes[1, :, 4:] == 0).all()
    assert (suffix_bytes[1, :, 6:] == 0).all()
    assert prefix_bytes[0][0][0] == 0
    assert prefix_bytes[0][1][0] != 0

    # No special characters were requested in the prefix call (""), so its
    # second element is expected to be all zeros.
    assert (prefixes[1] == 0).all()

    suffix_specials = suffixes[1]
    assert (suffix_specials[0][0] == 0).all()
    # 39, 40 == 0x27, 0x28: the UTF-16BE bytes of U+2728 (the sparkles emoji).
    assert suffix_specials[0][1].tolist() == [39, 40, 0, 0]
    # 114 == ord("r") and 80 == ord("P"), each preceded by a zero high byte.
    assert suffix_specials[0][3].tolist() == [0, 114, 0, 80]
|
||||
|
||||
|
||||
def test_get_affixes_4_byte_normal_char(en_tokenizer):
    """Check ``Doc.get_affixes`` when the text (not the special characters)
    contains a character that needs four bytes in UTF-16.
    """
    doc = en_tokenizer("and𐌞")
    suffixes = doc.get_affixes(True, 2, 6, "a", 1, 2)

    affix_bytes = suffixes[0]
    # 0xD8 (216) is the lead byte of a UTF-16 high surrogate.
    # NOTE(review): presumably this is the surrogate pair of the final
    # character of the text -- confirm against the implementation.
    assert (affix_bytes[:, 0, 2] == 0xD8).all()
    assert affix_bytes[3, 0, 9] == ord("a")

    special_bytes = suffixes[1]
    assert special_bytes[0, 0, 1] == ord("a")
|
||||
|
||||
|
||||
def test_get_affixes_4_byte_special_char(en_tokenizer):
    """``Doc.get_affixes`` must raise ``ValueError`` when a special character
    does not encode to exactly two bytes in UTF-16BE.
    """
    doc = en_tokenizer("and𐌞")
    # This character encodes to four bytes in UTF-16BE, so the encoded
    # length is not twice the string length and the call is rejected.
    four_byte_special = "𐌞"
    with pytest.raises(ValueError):
        doc.get_affixes(True, 2, 6, four_byte_special, 2, 3)
|
||||
|
|
|
@ -1743,6 +1743,8 @@ cdef class Doc:
|
|||
|
||||
special_chars_enc = special_chars.encode('utf-16BE')
|
||||
cdef int sc_test_len = len(special_chars)
|
||||
if sc_test_len * 2 != len(special_chars_enc):
|
||||
raise ValueError(Errors.E1044)
|
||||
|
||||
cdef np.ndarray[np.uint8_t, ndim=3] outputs = numpy.zeros(
|
||||
(len_end - len_start, num_tokens, (len_end - 1) * 2), dtype="uint8")
|
||||
|
|
Loading…
Reference in New Issue
Block a user