Add tests for NER oracle with whitespace

This commit is contained in:
Matthew Honnibal 2019-08-29 14:33:39 +02:00
parent 6511e1d8d3
commit 3c1c0ec18e

View File

@@ -91,3 +91,69 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
assert [w.ent_iob_ for w in doc] == ["", "", "", "B"]
doc.ents = [(doc.vocab.strings["WORD"], 0, 2)]
assert [w.ent_iob_ for w in doc] == ["B", "I", "", ""]
def test_oracle_moves_missing_B(en_vocab):
    """Compute an oracle sequence when only the last tag of an entity is given.

    The gold tags are ``[None, None, "L-PRODUCT"]``, i.e. the entity's closing
    ``L-`` tag is present while the preceding tokens carry no tag at all.
    No assertion is made on the resulting sequence: the test passes as long
    as ``preprocess_gold`` and ``get_oracle_sequence`` run without raising.
    """
    words = ["B", "52", "Bomber"]
    biluo_tags = [None, None, "L-PRODUCT"]

    doc = Doc(en_vocab, words=words)
    gold = GoldParse(doc, words=words, entities=biluo_tags)

    moves = BiluoPushDown(en_vocab.strings)
    # add_action() is given the position of the move name in this tuple.
    # NOTE(review): presumably mirrors BiluoPushDown's internal move-type
    # order -- confirm against the transition system.
    move_types = ("M", "B", "I", "L", "U", "O")
    for tag in biluo_tags:
        if tag is None:
            continue
        elif tag == "O":
            moves.add_action(move_types.index("O"), "")
        else:
            # Only the label matters here: every entity move is registered
            # for it, including the "B" that never appears in the gold tags.
            _, label = tag.split("-")
            for move in ("B", "I", "L", "U"):
                moves.add_action(move_types.index(move), label)
    moves.preprocess_gold(gold)
    moves.get_oracle_sequence(doc, gold)
def test_oracle_moves_whitespace(en_vocab):
    """Compute an oracle sequence for a text with newline tokens in an entity.

    The words contain three ``"\\n"`` tokens; the ORG entity spans across two
    of them, and one of those has no gold tag (``None``). No assertion is
    made on the resulting sequence: the test passes as long as
    ``preprocess_gold`` and ``get_oracle_sequence`` run without raising.
    """
    words = [
        "production",
        "\n",
        "of",
        "Northrop",
        "\n",
        "Corp.",
        "\n",
        "'s",
        "radar",
    ]
    biluo_tags = [
        "O",
        "O",
        "O",
        "B-ORG",
        None,
        "I-ORG",
        "L-ORG",
        "O",
        "O",
    ]

    doc = Doc(en_vocab, words=words)
    gold = GoldParse(doc, words=words, entities=biluo_tags)

    moves = BiluoPushDown(en_vocab.strings)
    # add_action() is given the position of the move name in this tuple.
    # NOTE(review): presumably mirrors BiluoPushDown's internal move-type
    # order -- confirm against the transition system.
    move_types = ("M", "B", "I", "L", "U", "O")
    for tag in biluo_tags:
        if tag is None:
            continue
        elif tag == "O":
            moves.add_action(move_types.index("O"), "")
        else:
            # Register exactly the move named by the tag prefix (B/I/L).
            action, label = tag.split("-")
            moves.add_action(move_types.index(action), label)
    moves.preprocess_gold(gold)
    moves.get_oracle_sequence(doc, gold)