Modify setting missing and blocked entity tokens

To make it easier to construct `Doc` objects as training data, change
how missing and blocked entity tokens are set: setting `O` and missing
entity tokens, which is what training needs, now takes priority over
setting blocked entity tokens.

* `Doc.ents` setter sets tokens outside entity spans to `O` regardless
of the current state of each token

* For `Doc.ents`, setting a span with a missing label sets the `ent_iob`
to missing instead of blocked

* `Doc.block_ents(spans)` marks spans as hard `O` for use with the
`EntityRecognizer`
This commit is contained in:
Adriane Boyd 2020-09-17 21:10:41 +02:00
parent 8303d101a5
commit 8b650f3a78
5 changed files with 42 additions and 21 deletions

View File

@ -137,7 +137,7 @@ def test_doc_api_set_ents(en_tokenizer):
assert len(tokens.ents) == 0 assert len(tokens.ents) == 0
tokens.ents = [(tokens.vocab.strings["PRODUCT"], 2, 4)] tokens.ents = [(tokens.vocab.strings["PRODUCT"], 2, 4)]
assert len(list(tokens.ents)) == 1 assert len(list(tokens.ents)) == 1
assert [t.ent_iob for t in tokens] == [0, 0, 3, 1, 0, 0, 0, 0] assert [t.ent_iob for t in tokens] == [2, 2, 3, 1, 2, 2, 2, 2]
assert tokens.ents[0].label_ == "PRODUCT" assert tokens.ents[0].label_ == "PRODUCT"
assert tokens.ents[0].start == 2 assert tokens.ents[0].start == 2
assert tokens.ents[0].end == 4 assert tokens.ents[0].end == 4
@ -426,7 +426,7 @@ def test_has_annotation(en_vocab):
doc[0].lemma_ = "a" doc[0].lemma_ = "a"
doc[0].dep_ = "dep" doc[0].dep_ = "dep"
doc[0].head = doc[1] doc[0].head = doc[1]
doc.ents = [Span(doc, 0, 1, label="HELLO")] doc.ents = [Span(doc, 0, 1, label="HELLO"), Span(doc, 1, 2, label="")]
for attr in attrs: for attr in attrs:
assert doc.has_annotation(attr) assert doc.has_annotation(attr)
@ -454,3 +454,17 @@ def test_is_flags_deprecated(en_tokenizer):
doc.is_nered doc.is_nered
with pytest.deprecated_call(): with pytest.deprecated_call():
doc.is_sentenced doc.is_sentenced
def test_block_ents(en_tokenizer):
# Checks Doc.block_ents: blocked tokens get ent_iob 3 with ent_type 0, and
# an "I" token left directly after a blocked span is rewritten to "B".
# ent_iob codes as used in this file: 0 = missing, 1 = I, 2 = O, 3 = B.
doc = en_tokenizer("a b c d e")
# Block "b" and "d e" on a doc without any entity annotation.
doc.block_ents([doc[1:2], doc[3:5]])
# Blocked tokens are 3; the tokens that were not blocked stay 0 (missing).
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
# Blocked tokens have no entity type, and doc.ents reports no entities.
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
assert doc.ents == tuple()
# invalid IOB repaired
# Set a two-token "ENT" entity over tokens 3-4; the Doc.ents setter marks
# every token outside the entity as 2 (O).
doc.ents = [Span(doc, 3, 5, "ENT")]
assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 1]
# Blocking only token 3 would leave token 4 as an I right after a blocked
# token; block_ents turns that I (1) into B (3).
doc.block_ents([doc[3:4]])
assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 3]

View File

@ -168,7 +168,7 @@ def test_accept_blocked_token():
ner2 = nlp2.create_pipe("ner", config=config) ner2 = nlp2.create_pipe("ner", config=config)
# set "New York" to a blocked entity # set "New York" to a blocked entity
doc2.ents = [(0, 3, 5)] doc2.block_ents([doc2[3:5]])
assert [token.ent_iob_ for token in doc2] == ["", "", "", "B", "B"] assert [token.ent_iob_ for token in doc2] == ["", "", "", "B", "B"]
assert [token.ent_type_ for token in doc2] == ["", "", "", "", ""] assert [token.ent_type_ for token in doc2] == ["", "", "", "", ""]
@ -358,5 +358,5 @@ class BlockerComponent1:
self.name = name self.name = name
def __call__(self, doc): def __call__(self, doc):
doc.ents = [(0, self.start, self.end)] doc.block_ents([doc[self.start:self.end]])
return doc return doc

View File

@ -590,17 +590,16 @@ cdef class Doc:
entity_type = 0 entity_type = 0
kb_id = 0 kb_id = 0
# Set ent_iob to Missing (0) by default unless this token was nered before # Set ent_iob to Outside (2) by default
ent_iob = 0
if self.c[i].ent_iob != 0:
ent_iob = 2 ent_iob = 2
# overwrite if the token was part of a specified entity # overwrite if the token was part of a specified entity
if i in tokens_in_ents.keys(): if i in tokens_in_ents.keys():
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i] ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
if entity_type is None or entity_type <= 0: if entity_type is None or entity_type <= 0:
# Blocking this token from being overwritten by downstream NER # Empty label: Missing, unset this token
ent_iob = 3 ent_iob = 0
entity_type = 0
elif ent_start == i: elif ent_start == i:
# Marking the start of an entity # Marking the start of an entity
ent_iob = 3 ent_iob = 3
@ -612,6 +611,20 @@ cdef class Doc:
self.c[i].ent_kb_id = kb_id self.c[i].ent_kb_id = kb_id
self.c[i].ent_iob = ent_iob self.c[i].ent_iob = ent_iob
def block_ents(self, spans):
"""Mark spans as never an entity for the EntityRecognizer.
spans (List[Span]): The spans to block as never entities.
"""
# A blocked token is stored as ent_iob 3 ("B") together with ent_type 0,
# i.e. a begin marker that carries no entity label.
# NOTE(review): ent_kb_id is left untouched here, whereas the Doc.ents
# setter writes it for every token -- confirm that keeping a previous
# ent_kb_id on a blocked token is intended.
for span in spans:
for i in range(span.start, span.end):
self.c[i].ent_iob = 3
self.c[i].ent_type = 0
# if the following token is I, set to B
# (otherwise an I token would sit directly after a blocked token)
# NOTE(review): indentation is lost in this view; the check reads `span`,
# so it presumably runs once per span inside the loop -- confirm.
# Skipped when the span reaches the end of the doc, so self.c[span.end] is
# only read for an index below self.length.
if span.end < self.length:
if self.c[span.end].ent_iob == 1:
self.c[span.end].ent_iob = 3
@property @property
def noun_chunks(self): def noun_chunks(self):
"""Iterate over the base noun phrases in the document. Yields base """Iterate over the base noun phrases in the document. Yields base

View File

@ -172,7 +172,7 @@ cdef class Example:
return output return output
def get_aligned_ner(self): def get_aligned_ner(self):
if not self.y.is_nered: if not self.y.has_annotation("ENT_IOB"):
return [None] * len(self.x) # should this be 'missing' instead of 'None' ? return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
x_ents = self.get_aligned_spans_y2x(self.y.ents) x_ents = self.get_aligned_spans_y2x(self.y.ents)
# Default to 'None' for missing values # Default to 'None' for missing values
@ -303,9 +303,7 @@ def _add_entities_to_doc(doc, ner_data):
spans_from_biluo_tags(doc, ner_data) spans_from_biluo_tags(doc, ner_data)
) )
elif isinstance(ner_data[0], Span): elif isinstance(ner_data[0], Span):
# Ugh, this is super messy. Really hard to set O entities
doc.ents = ner_data doc.ents = ner_data
doc.ents = [span for span in ner_data if span.label_]
else: else:
raise ValueError(Errors.E973) raise ValueError(Errors.E973)

View File

@ -182,22 +182,18 @@ def tags_to_entities(tags):
entities = [] entities = []
start = None start = None
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
if tag is None: if tag is None or tag.startswith("-"):
continue
if tag.startswith("O"):
# TODO: We shouldn't be getting these malformed inputs. Fix this. # TODO: We shouldn't be getting these malformed inputs. Fix this.
if start is not None: if start is not None:
start = None start = None
else: else:
entities.append(("", i, i)) entities.append(("", i, i))
continue elif tag.startswith("O"):
elif tag == "-": pass
continue
elif tag.startswith("I"): elif tag.startswith("I"):
if start is None: if start is None:
raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1])) raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1]))
continue elif tag.startswith("U"):
if tag.startswith("U"):
entities.append((tag[2:], i, i)) entities.append((tag[2:], i, i))
elif tag.startswith("B"): elif tag.startswith("B"):
start = i start = i