Implement Doc.set_ents

This commit is contained in:
Adriane Boyd 2020-09-21 15:54:05 +02:00
parent 13fbf6556a
commit 177df15d89
6 changed files with 192 additions and 21 deletions

View File

@ -682,6 +682,15 @@ class Errors:
E1009 = ("String for hash '{val}' not found in StringStore. Set the value "
"through token.morph_ instead or add the string to the "
"StringStore with `nlp.vocab.strings.add(string)`.")
E1010 = ("Unable to set entity information for token {i} which is included "
"in more than one span in entities, blocked, missing or outside.")
E1011 = ("Unsupported default '{default}' in doc.set_ents. Available "
"options: {modes}")
E1012 = ("Spans provided to doc.set_ents must be provided as a list of "
"`Span` objects.")
E1013 = ("Unable to set entity for span with empty label. Entity spans are "
"required to have a label. To set entity information as missing "
"or blocked, use the keyword arguments with doc.set_ents.")
@add_codes

View File

@ -425,7 +425,7 @@ def test_has_annotation(en_vocab):
doc[0].lemma_ = "a"
doc[0].dep_ = "dep"
doc[0].head = doc[1]
doc.ents = [Span(doc, 0, 1, label="HELLO"), Span(doc, 1, 2, label="")]
doc.set_ents([Span(doc, 0, 1, label="HELLO")], default="missing")
for attr in attrs:
assert doc.has_annotation(attr)
@ -455,15 +455,68 @@ def test_is_flags_deprecated(en_tokenizer):
doc.is_sentenced
def test_block_ents(en_tokenizer):
def test_set_ents(en_tokenizer):
# set ents
doc = en_tokenizer("a b c d e")
doc.block_ents([doc[1:2], doc[3:5]])
doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)])
assert [t.ent_iob for t in doc] == [3, 3, 1, 2, 2]
assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
# add ents, invalid IOB repaired
doc = en_tokenizer("a b c d e")
doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)])
doc.set_ents([Span(doc, 0, 2, 12)], default="unmodified")
assert [t.ent_iob for t in doc] == [3, 1, 3, 2, 2]
assert [t.ent_type for t in doc] == [12, 12, 11, 0, 0]
# missing ents
doc = en_tokenizer("a b c d e")
doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)], missing=[doc[4:5]])
assert [t.ent_iob for t in doc] == [3, 3, 1, 2, 0]
assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
# outside ents
doc = en_tokenizer("a b c d e")
doc.set_ents(
[Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)],
outside=[doc[4:5]],
default="missing",
)
assert [t.ent_iob for t in doc] == [3, 3, 1, 0, 2]
assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
# blocked ents
doc = en_tokenizer("a b c d e")
doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
assert doc.ents == tuple()
# invalid IOB repaired
# invalid IOB repaired after blocked
doc.ents = [Span(doc, 3, 5, "ENT")]
assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 1]
doc.block_ents([doc[3:4]])
doc.set_ents([], blocked=[doc[3:4]], default="unmodified")
assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 3]
# all types
doc = en_tokenizer("a b c d e")
doc.set_ents(
[Span(doc, 0, 1, 10)],
blocked=[doc[1:2]],
missing=[doc[2:3]],
outside=[doc[3:4]],
default="unmodified",
)
assert [t.ent_iob for t in doc] == [3, 3, 0, 2, 0]
assert [t.ent_type for t in doc] == [10, 0, 0, 0, 0]
doc = en_tokenizer("a b c d e")
# single span instead of a list
with pytest.raises(ValueError):
doc.set_ents([], missing=doc[1:2])
# invalid default mode
with pytest.raises(ValueError):
doc.set_ents([], missing=[doc[1:2]], default="none")
# conflicting/overlapping specifications
with pytest.raises(ValueError):
doc.set_ents([], missing=[doc[1:2]], outside=[doc[1:2]])

View File

@ -168,7 +168,7 @@ def test_accept_blocked_token():
ner2 = nlp2.create_pipe("ner", config=config)
# set "New York" to a blocked entity
doc2.block_ents([doc2[3:5]])
doc2.set_ents([], blocked=[doc2[3:5]], default="unmodified")
assert [token.ent_iob_ for token in doc2] == ["", "", "", "B", "B"]
assert [token.ent_type_ for token in doc2] == ["", "", "", "", ""]
@ -358,5 +358,5 @@ class BlockerComponent1:
self.name = name
def __call__(self, doc):
doc.block_ents([doc[self.start:self.end]])
doc.set_ents([], blocked=[doc[self.start:self.end]], default="unmodified")
return doc

View File

@ -7,6 +7,7 @@ from libc.stdint cimport int32_t, uint64_t
import copy
from collections import Counter
from enum import Enum
import numpy
import srsly
from thinc.api import get_array_module
@ -86,6 +87,17 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
return get_token_attr(token, feat_name)
class SetEntsDefault(str, Enum):
blocked = "blocked"
missing = "missing"
outside = "outside"
unmodified = "unmodified"
@classmethod
def values(cls):
return list(cls.__members__.keys())
cdef class Doc:
"""A sequence of Token objects. Access sentences and named entities, export
annotations to numpy arrays, losslessly serialize to compressed binary
@ -597,9 +609,9 @@ cdef class Doc:
if i in tokens_in_ents.keys():
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
if entity_type is None or entity_type <= 0:
# Empty label: Missing, unset this token
ent_iob = 0
entity_type = 0
# Only allow labelled spans
print(i, ent_start, ent_end, entity_type)
raise ValueError(Errors.E1013)
elif ent_start == i:
# Marking the start of an entity
ent_iob = 3
@ -611,19 +623,107 @@ cdef class Doc:
self.c[i].ent_kb_id = kb_id
self.c[i].ent_iob = ent_iob
def block_ents(self, spans):
"""Mark spans as never an entity for the EntityRecognizer.
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
"""Set entity annotation.
spans (List[Span]): The spans to block as never entities.
entities (List[Span]): Spans with labels to set as entities.
blocked (Optional[List[Span]]): Spans to set as 'blocked' (never an
entity) for spacy's built-in NER component. Other components may
ignore this setting.
missing (Optional[List[Span]]): Spans with missing/unknown entity
information.
outside (Optional[List[Span]]): Spans outside of entities (O in IOB).
default (str): How to set entity annotation for tokens outside of any
provided spans. Options: "blocked", "missing", "outside" and
"unmodified" (preserve current state). Defaults to "outside".
"""
for span in spans:
if default not in SetEntsDefault.values():
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
if blocked is None:
blocked = tuple()
if missing is None:
missing = tuple()
if outside is None:
outside = tuple()
# Find all tokens covered by spans and check that none are overlapping
seen_tokens = set()
for span in entities:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in blocked:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in missing:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in outside:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
# Set all specified entity information
for span in entities:
for i in range(span.start, span.end):
if not span.label:
raise ValueError(Errors.E1013)
if i == span.start:
self.c[i].ent_iob = 3
else:
self.c[i].ent_iob = 1
self.c[i].ent_type = span.label
for span in blocked:
for i in range(span.start, span.end):
self.c[i].ent_iob = 3
self.c[i].ent_type = 0
# if the following token is I, set to B
if span.end < self.length:
if self.c[span.end].ent_iob == 1:
self.c[span.end].ent_iob = 3
for span in missing:
for i in range(span.start, span.end):
self.c[i].ent_iob = 0
self.c[i].ent_type = 0
for span in outside:
for i in range(span.start, span.end):
self.c[i].ent_iob = 2
self.c[i].ent_type = 0
# Set tokens outside of all provided spans
if default != SetEntsDefault.unmodified:
for i in range(self.length):
if i not in seen_tokens:
self.c[i].ent_type = 0
if default == SetEntsDefault.outside:
self.c[i].ent_iob = 2
elif default == SetEntsDefault.missing:
self.c[i].ent_iob = 0
elif default == SetEntsDefault.blocked:
self.c[i].ent_iob = 3
# Fix any resulting inconsistent annotation
for i in range(self.length - 1):
# I must follow B or I: convert I to B
if (self.c[i].ent_iob == 0 or self.c[i].ent_iob == 2) and \
self.c[i+1].ent_iob == 1:
self.c[i+1].ent_iob = 3
# Change of type with BI or II: convert second I to B
if self.c[i].ent_type != self.c[i+1].ent_type and \
(self.c[i].ent_iob == 3 or self.c[i].ent_iob == 1) and \
self.c[i+1].ent_iob == 1:
self.c[i+1].ent_iob = 3
@property
def noun_chunks(self):

View File

@ -288,6 +288,7 @@ def _annot2array(vocab, tok_annot, doc_annot):
def _add_entities_to_doc(doc, ner_data):
print(ner_data)
if ner_data is None:
return
elif ner_data == []:
@ -303,7 +304,14 @@ def _add_entities_to_doc(doc, ner_data):
spans_from_biluo_tags(doc, ner_data)
)
elif isinstance(ner_data[0], Span):
doc.ents = ner_data
entities = []
missing = []
for span in ner_data:
if span.label:
entities.append(span)
else:
missing.append(span)
doc.set_ents(entities, missing=missing)
else:
raise ValueError(Errors.E973)

View File

@ -149,9 +149,10 @@ def spans_from_biluo_tags(doc, tags):
doc (Doc): The document that the BILUO tags refer to.
entities (iterable): A sequence of BILUO tags with each tag describing one
token. Each tags string will be of the form of either "", "O" or
token. Each tag string will be of the form of either "", "O" or
"{action}-{label}", where action is one of "B", "I", "L", "U".
RETURNS (list): A sequence of Span objects.
RETURNS (list): A sequence of Span objects. Each token with a missing IOB
tag is returned as a Span with an empty label.
"""
token_offsets = tags_to_entities(tags)
spans = []