fixing NER one-to-many alignment

This commit is contained in:
svlandeg 2020-06-15 22:44:17 +02:00
parent a0bf73a5dd
commit 12886b787b
3 changed files with 26 additions and 22 deletions

View File

@ -71,12 +71,13 @@ cdef class Example:
self._alignment = Alignment(spacy_words, gold_words)
return self._alignment
def get_aligned(self, field):
def get_aligned(self, field, as_string=False):
"""Return an aligned array for a token attribute."""
# TODO: This is probably wrong. I just bashed this out and there's probably
# all sorts of edge-cases.
alignment = self.alignment
i2j_multi = alignment.i2j_multi
j2i_multi = alignment.j2i_multi
gold_to_cand = alignment.gold_to_cand
cand_to_gold = alignment.cand_to_gold
@ -92,8 +93,18 @@ cdef class Example:
else:
output.append(None)
else:
output.append([gold_values[gold_i]])
output = [vocab.strings[o] for o in output]
output.append(gold_values[gold_i])
if field in ["ENT_IOB", "ENT_TYPE"]:
# Assign O/- for one-to-many O/- NER tags
for j, cand_j in enumerate(gold_to_cand):
if cand_j is None:
if j in j2i_multi:
i = j2i_multi[j]
output[i] = gold_values[j]
if as_string:
output = [vocab.strings[o] if o is not None else o for o in output]
return output
def to_dict(self):

View File

@ -1,12 +1,10 @@
from spacy.errors import AlignmentError
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
from spacy.gold import GoldCorpus, docs_to_json, DocAnnotation
from spacy.gold import GoldCorpus, docs_to_json
from spacy.gold.example import Example
from spacy.lang.en import English
from spacy.syntax.nonproj import is_nonproj_tree
from spacy.syntax.gold_parse import GoldParse, get_parses_from_example
from spacy.syntax.gold_parse import get_parses_from_example
from spacy.tokens import Doc
from spacy.util import get_words_and_spaces, compounding, minibatch
import pytest
@ -158,12 +156,10 @@ def test_gold_biluo_different_tokenization(en_vocab, en_tokenizer):
spaces = [True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
gp = GoldParse(
doc,
words=["I", "flew", "to", "San", "Francisco", "Valley", "."],
entities=entities,
)
assert gp.ner == ["O", "O", "U-LOC", "O"]
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
assert example.get_aligned("ENT_IOB") == [2, 2, 1, 2]
assert example.get_aligned("ENT_TYPE", as_string=True) == ["", "", "LOC", ""]
# many-to-one
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]

View File

@ -40,7 +40,7 @@ def test_Example_from_dict_with_tags(pred_words, annots):
example = Example.from_dict(predicted, annots)
for i, token in enumerate(example.reference):
assert token.tag_ == annots["tags"][i]
aligned_tags = example.get_aligned("tag")
aligned_tags = example.get_aligned("tag", as_string=True)
assert aligned_tags == ["NN" for _ in predicted]
@ -52,7 +52,7 @@ def test_aligned_tags():
vocab = Vocab()
predicted = Doc(vocab, words=pred_words)
example = Example.from_dict(predicted, annots)
aligned_tags = example.get_aligned("tag")
aligned_tags = example.get_aligned("tag", as_string=True)
assert aligned_tags == ["VERB", "DET", None, "SCONJ", "PRON", "VERB", "VERB"]
@ -64,7 +64,7 @@ def test_aligned_tags_multi():
vocab = Vocab()
predicted = Doc(vocab, words=pred_words)
example = Example.from_dict(predicted, annots)
aligned_tags = example.get_aligned("tag")
aligned_tags = example.get_aligned("tag", as_string=True)
assert aligned_tags == [None, None, "SCONJ", "PRON", "VERB", "VERB"]
@ -159,14 +159,11 @@ def test_Example_from_dict_with_entities(annots):
vocab = Vocab()
predicted = Doc(vocab, words=annots["words"])
example = Example.from_dict(predicted, annots)
assert len(list(example.reference.ents)) == 2
assert example.reference[0].ent_iob_ == "O"
assert example.reference[1].ent_iob_ == "O"
assert example.reference[2].ent_iob_ == "B"
assert example.reference[3].ent_iob_ == "I"
assert example.reference[4].ent_iob_ == "O"
assert example.reference[5].ent_iob_ == "B"
assert example.reference[6].ent_iob_ == "O"
assert [example.reference[i].ent_iob_ for i in range(7)] == ["O", "O", "B", "I", "O", "B", "O"]
assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2, 3, 2]
assert example.reference[2].ent_type_ == "LOC"
assert example.reference[3].ent_type_ == "LOC"
assert example.reference[5].ent_type_ == "LOC"