Refactor `Doc.ents` setter to use `Doc.set_ents`

Additional changes:

* Ignore entity spans with missing labels
* Fix `ent_kb_id` setting in `Doc.set_ents`
This commit is contained in:
Adriane Boyd 2020-09-24 12:36:51 +02:00
parent b1a7d6c528
commit 8eaacaae97
3 changed files with 14 additions and 42 deletions

View File

@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
ner.begin_training(lambda: [_ner_example(ner)])
ner(doc)
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
doc.ents = [("ANIMAL", 3, 4)]
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
doc.ents = [(doc.vocab.strings["WORD"], 0, 2)]
doc.ents = [("WORD", 0, 2)]
assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]

View File

@ -534,4 +534,4 @@ def test_doc_ents_setter():
vocab = Vocab()
ents = [("HELLO", 0, 2), (vocab.strings.add("WORLD"), 3, 5)]
doc = Doc(vocab, words=words, ents=ents)
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]

View File

@ -673,49 +673,16 @@ cdef class Doc:
# TODO:
# 1. Test basic data-driven ORTH gazetteer
# 2. Test more nuanced date and currency regex
tokens_in_ents = {}
cdef attr_t entity_type
cdef attr_t kb_id
cdef int ent_start, ent_end, token_index
cdef attr_t entity_type, kb_id
cdef int ent_start, ent_end
ent_spans = []
for ent_info in ents:
entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
if isinstance(entity_type_, str):
self.vocab.strings.add(entity_type_)
entity_type = self.vocab.strings.as_int(entity_type_)
for token_index in range(ent_start, ent_end):
if token_index in tokens_in_ents:
raise ValueError(Errors.E103.format(
span1=(tokens_in_ents[token_index][0],
tokens_in_ents[token_index][1],
self.vocab.strings[tokens_in_ents[token_index][2]]),
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
cdef int i
for i in range(self.length):
# default values
entity_type = 0
kb_id = 0
# Set ent_iob to Outside (2) by default
ent_iob = 2
# overwrite if the token was part of a specified entity
if i in tokens_in_ents.keys():
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
if entity_type is None or entity_type <= 0:
# Only allow labelled spans
print(i, ent_start, ent_end, entity_type)
raise ValueError(Errors.E1013)
elif ent_start == i:
# Marking the start of an entity
ent_iob = 3
else:
# Marking the inside of an entity
ent_iob = 1
self.c[i].ent_type = entity_type
self.c[i].ent_kb_id = kb_id
self.c[i].ent_iob = ent_iob
span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id)
ent_spans.append(span)
self.set_ents(ent_spans, default=SetEntsDefault.outside)
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
"""Set entity annotation.
@ -734,6 +701,9 @@ cdef class Doc:
if default not in SetEntsDefault.values():
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
# Ignore spans with missing labels
entities = [ent for ent in entities if ent.label > 0]
if blocked is None:
blocked = tuple()
if missing is None:
@ -742,6 +712,7 @@ cdef class Doc:
outside = tuple()
# Find all tokens covered by spans and check that none are overlapping
cdef int i
seen_tokens = set()
for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
if not isinstance(span, Span):
@ -761,6 +732,7 @@ cdef class Doc:
else:
self.c[i].ent_iob = 1
self.c[i].ent_type = span.label
self.c[i].ent_kb_id = span.kb_id
for span in blocked:
for i in range(span.start, span.end):
self.c[i].ent_iob = 3