Fix handling of NER data in Example

This commit is contained in:
Matthew Honnibal 2020-06-24 18:03:24 +02:00
parent 359e874766
commit 7eb064854e

View File

@ -4,6 +4,8 @@ import numpy
from ..tokens import Token from ..tokens import Token
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..tokens.span cimport Span
from ..tokens.span import Span
from ..attrs import IDS from ..attrs import IDS
from .align cimport Alignment from .align cimport Alignment
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
@ -19,6 +21,8 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"]) output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
if array.size: if array.size:
output = output.from_array(attrs, array) output = output.from_array(attrs, array)
if "entities" in doc_annot:
_add_entities_to_doc(output, doc_annot["entities"])
# TODO: links ?! # TODO: links ?!
output.cats.update(doc_annot.get("cats", {})) output.cats.update(doc_annot.get("cats", {}))
return output return output
@ -99,29 +103,6 @@ cdef class Example:
output[i] = None output[i] = None
else: else:
output[i] = gold_values[gold_i] output[i] = gold_values[gold_i]
if field in ["ENT_IOB"]:
# Fix many-to-one IOB codes
prev_j = -1
prev_value = -1
for i, value in enumerate(output):
if i in i2j_multi:
j = i2j_multi[i]
if j == prev_j and prev_value == value == 3:
output[i] = 1 # set B to I
prev_j = j
else:
prev_j = -1
prev_value = value
if field in ["ENT_IOB", "ENT_TYPE", "ENT_KB_ID"]:
# Assign one-to-many NER tags
for j, cand_j in enumerate(gold_to_cand):
if cand_j is None:
if j in j2i_multi:
i = j2i_multi[j]
if output[i] is None:
output[i] = gold_values[j]
if as_string and field not in ["ENT_IOB", "SENT_START"]: if as_string and field not in ["ENT_IOB", "SENT_START"]:
output = [vocab.strings[o] if o is not None else o for o in output] output = [vocab.strings[o] if o is not None else o for o in output]
return output return output
@ -145,15 +126,30 @@ cdef class Example:
def get_aligned_ner(self): def get_aligned_ner(self):
x_ents = [] x_ents = []
gold_to_cand = self.alignment.gold_to_cand
for y_ent in self.y.ents: for y_ent in self.y.ents:
x_span = self.x.char_span(y_ent.start_char, y_ent.end_char, label=y_ent.label) x_start = gold_to_cand[y_ent.start]
if x_span is not None: x_end = gold_to_cand[y_ent.end-1]
x_ents.append(x_span) if x_start is not None and x_end is not None:
x_ents.append(Span(self.x, x_start, x_end+1, label=y_ent.label))
else:
x_span = self.x.char_span(
y_ent.start_char,
y_ent.end_char,
label=y_ent.label
)
if x_span is not None:
x_ents.append(x_span)
x_tags = biluo_tags_from_offsets( x_tags = biluo_tags_from_offsets(
self.x, self.x,
[(e.start_char, e.end_char, e.label_) for e in x_ents], [(e.start_char, e.end_char, e.label_) for e in x_ents],
missing="O" missing="O"
) )
for token in self.y:
if token.ent_iob == 0:
cand_i = gold_to_cand[token.i]
if cand_i is not None:
x_tags[cand_i] = None
return x_tags return x_tags
def to_dict(self): def to_dict(self):
@ -222,11 +218,7 @@ def _annot2array(vocab, tok_annot, doc_annot):
for key, value in doc_annot.items(): for key, value in doc_annot.items():
if value: if value:
if key == "entities": if key == "entities":
words = tok_annot["ORTH"] pass
spaces = tok_annot["SPACY"]
ent_iobs, ent_types = _parse_ner_tags(value, vocab, words, spaces)
tok_annot["ENT_IOB"] = ent_iobs
tok_annot["ENT_TYPE"] = ent_types
elif key == "links": elif key == "links":
entities = doc_annot.get("entities", {}) entities = doc_annot.get("entities", {})
if value and not entities: if value and not entities:
@ -252,13 +244,6 @@ def _annot2array(vocab, tok_annot, doc_annot):
elif key == "MORPH": elif key == "MORPH":
attrs.append(key) attrs.append(key)
values.append([vocab.morphology.add(v) for v in value]) values.append([vocab.morphology.add(v) for v in value])
elif key == "ENT_IOB":
iob_strings = Token.iob_strings()
attrs.append(key)
try:
values.append([iob_strings.index(v) for v in value])
except ValueError:
raise ValueError(Errors.E982.format(values=iob_strings, value=values))
else: else:
attrs.append(key) attrs.append(key)
values.append([vocab.strings.add(v) for v in value]) values.append([vocab.strings.add(v) for v in value])
@ -267,6 +252,29 @@ def _annot2array(vocab, tok_annot, doc_annot):
return attrs, array.T return attrs, array.T
def _add_entities_to_doc(doc, ner_data):
if ner_data is None:
return
elif ner_data == []:
doc.ents = []
elif isinstance(ner_data[0], tuple):
return _add_entities_to_doc(
doc,
biluo_tags_from_offsets(doc, ner_data)
)
elif isinstance(ner_data[0], str) or ner_data[0] is None:
return _add_entities_to_doc(
doc,
spans_from_biluo_tags(doc, ner_data)
)
elif isinstance(ner_data[0], Span):
# Ugh, this is super messy. Really hard to set O entities
doc.ents = ner_data
doc.ents = [span for span in ner_data if span.label_]
else:
raise ValueError("Unexpected type for NER data")
def _parse_example_dict_data(example_dict): def _parse_example_dict_data(example_dict):
return ( return (
example_dict["token_annotation"], example_dict["token_annotation"],
@ -332,7 +340,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
ent_iobs = [] ent_iobs = []
ent_types = [] ent_types = []
for iob_tag in biluo_to_iob(biluo): for iob_tag in biluo_to_iob(biluo):
if iob_tag is None: if iob_tag in (None, "-"):
ent_iobs.append("") ent_iobs.append("")
ent_types.append("") ent_types.append("")
else: else: