Simplify NER alignment

This commit is contained in:
Matthew Honnibal 2020-06-23 23:57:54 +02:00
parent 10eb44d76a
commit b82431207d

View File

@ -144,58 +144,16 @@ cdef class Example:
return aligned_heads, aligned_deps
def get_aligned_ner(self):
cand_to_gold = self.alignment.cand_to_gold
gold_to_cand = self.alignment.gold_to_cand
i2j_multi = self.alignment.i2j_multi
j2i_multi = self.alignment.j2i_multi
y_tags = biluo_tags_from_offsets(
self.y,
[(e.start_char, e.end_char, e.label_) for e in self.y.ents]
x_ents = []
for y_ent in self.y.ents:
x_span = self.x.char_span(y_ent.start_char, y_ent.end_char, label=y_ent.label)
if x_span is not None:
x_ents.append(x_span)
x_tags = biluo_tags_from_offsets(
self.x,
[(e.start_char, e.end_char, e.label_) for e in x_ents],
missing="O"
)
x_tags = [None] * self.x.length
for i in range(self.x.length):
if self.x[i].is_space:
pass
elif cand_to_gold[i] is not None:
x_tags[i] = y_tags[cand_to_gold[i]]
elif i in i2j_multi:
# Assign O/- for many-to-one O/- NER tags
if y_tags[i2j_multi[i]] in ("O", "-"):
x_tags[i] = y_tags[i2j_multi[i]]
# Assign O/- for one-to-many O/- NER tags
for gold_i, cand_i in enumerate(gold_to_cand):
if y_tags[gold_i] in ("O", "-"):
if cand_i is None and gold_i in j2i_multi:
x_tags[j2i_multi[gold_i]] = y_tags[gold_i]
# TODO: I'm copying this over from v2.x but this seems kind of nuts?
# If there is entity annotation and some tokens remain unaligned,
# align all entities at the character level to account for all
# possible token misalignments within the entity spans
if list(self.y.ents) and None in x_tags:
# Get offsets based on gold words and BILUO entities
aligned_offsets = []
aligned_spans = []
# Filter offsets to identify those that align with doc tokens
for span in spans_from_biluo_tags(self.x, x_tags):
if span and not span.text.isspace():
aligned_offsets.append(
(span.start_char, span.end_char, span.label_)
)
aligned_spans.append(span)
# Convert back to BILUO for doc tokens and assign NER for all
# aligned spans
aligned_tags = biluo_tags_from_offsets(self.x, aligned_offsets, missing=None)
for span in aligned_spans:
for i in range(span.start, span.end):
x_tags[i] = aligned_tags[i]
# Prevent whitespace that isn't within entities from being tagged as
# an entity.
for i, token in enumerate(self.x):
if token.is_space:
prev_ner = x_tags[i] if i >= 1 else None
next_ner = x_tags[i+1] if (i+1) < self.x.length else None
if prev_ner == "O" or next_ner == "O":
x_tags[i] = "O"
return x_tags
def to_dict(self):