mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Update test for #3012
This commit is contained in:
parent
3d64eb4a74
commit
bf20252ae0
|
@ -8,33 +8,26 @@ from ...tokens import Doc
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue3012(en_vocab):
|
def test_issue3012(en_vocab):
|
||||||
|
"""Test that the is_tagged attribute doesn't get overwritten when we from_array
|
||||||
|
without tag information."""
|
||||||
words = ["This", "is", "10", "%", "."]
|
words = ["This", "is", "10", "%", "."]
|
||||||
tags = ["DT", "VBZ", "CD", "NN", "."]
|
tags = ["DT", "VBZ", "CD", "NN", "."]
|
||||||
pos = ["DET", "VERB", "NUM", "NOUN", "PUNCT"]
|
pos = ["DET", "VERB", "NUM", "NOUN", "PUNCT"]
|
||||||
ents = [(2, 4, "PERCENT")]
|
ents = [(2, 4, "PERCENT")]
|
||||||
doc = get_doc(en_vocab, words=words, tags=tags, pos=pos, ents=ents)
|
doc = get_doc(en_vocab, words=words, tags=tags, pos=pos, ents=ents)
|
||||||
|
assert doc.is_tagged
|
||||||
|
|
||||||
expected = ("10", "NUM", "CD", "PERCENT")
|
expected = ("10", "NUM", "CD", "PERCENT")
|
||||||
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
||||||
|
|
||||||
# now removing '10%' entity
|
|
||||||
assert len(doc.ents) == 1
|
|
||||||
header = [ENT_IOB, ENT_TYPE]
|
header = [ENT_IOB, ENT_TYPE]
|
||||||
ent_array = doc.to_array(header)
|
ent_array = doc.to_array(header)
|
||||||
ent = doc.ents[0]
|
|
||||||
for idx in range(ent.start, ent.end):
|
|
||||||
ent_array[idx, 0] = 0
|
|
||||||
ent_array[idx, 1] = 0
|
|
||||||
doc.from_array(header, ent_array)
|
doc.from_array(header, ent_array)
|
||||||
assert len(doc.ents) == 0
|
|
||||||
|
|
||||||
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
||||||
|
|
||||||
# serializing then deserializing
|
# serializing then deserializing
|
||||||
doc_bytes = doc.to_bytes()
|
doc_bytes = doc.to_bytes()
|
||||||
doc2 = Doc(en_vocab).from_bytes(doc_bytes)
|
doc2 = Doc(en_vocab).from_bytes(doc_bytes)
|
||||||
assert len(doc2.ents) == 0
|
|
||||||
|
|
||||||
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
assert (doc[2].text, doc[2].pos_, doc[2].tag_, doc[2].ent_type_) == expected
|
||||||
|
|
Loading…
Reference in New Issue
Block a user