spaCy/spacy/tests/parser/test_arc_eager_oracle.py
Matthew Honnibal 6f5e308d17
Support negative examples in partial NER annotations (#8106)
* Support a cfg field in transition system

* Make NER 'has gold' check use right alignment for span

* Pass 'negative_samples_key' property into NER transition system

* Add field for negative samples to NER transition system

* Check neg_key in NER has_gold

* Support negative examples in NER oracle

* Test for negative examples in NER

* Fix name of config variable in NER

* Remove vestiges of old-style partial annotation

* Remove obsolete tests

* Add comment noting lack of support for negative samples in parser

* Additions to "neg examples" PR (#8201)

* add custom error and test for deprecated format

* add test for unlearning an entity

* add break also for Begin's cost

* add negative_samples_key property on Parser

* rename

* extend docs & fix some older docs issues

* add subclass constructors, clean up tests, fix docs

* add flaky test with ValueError if gold parse was not found

* remove ValueError if n_gold == 0

* fix docstring

* Hack in environment variables to try out training

* Remove hack

* Remove NER hack, and support 'negative O' samples

* Fix O oracle

* Fix transition parser

* Remove 'not O' from oracle

* Fix NER oracle

* check for spans in both gold.ents and gold.spans and raise if so, to prevent memory access violation

* use set instead of list in consistency check

Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2021-06-17 17:33:00 +10:00

272 lines
9.1 KiB
Python

import pytest
from spacy.vocab import Vocab
from spacy import registry
from spacy.training import Example
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
def get_sequence_costs(M, words, heads, deps, transitions):
    """Replay ``transitions`` and record the cost of every move at each step.

    M: The transition system; it must provide ``init_gold_batch``,
        ``n_moves``, ``class_name``, ``get_cost`` and ``transition``.
    words: Token texts used to build the ``Doc``.
    heads: Head annotation passed to ``Example.from_dict``.
    deps: Dependency labels passed to ``Example.from_dict``.
    transitions: Actions applied in order via ``M.transition``.
    RETURNS: A ``(state, cost_history)`` pair. ``state`` is the state after
        all transitions were applied; ``cost_history`` holds one dict per
        transition, mapping each move name to its cost as measured *before*
        that transition was applied.
    """
    example = Example.from_dict(
        Doc(Vocab(), words=words), {"heads": heads, "deps": deps}
    )
    states, golds, _ = M.init_gold_batch([example])
    state, gold = states[0], golds[0]
    cost_history = []
    for gold_action in transitions:
        gold.update(state)
        # Snapshot the cost of every available move in the current state.
        step_costs = {
            M.class_name(move): M.get_cost(state, gold, move)
            for move in range(M.n_moves)
        }
        M.transition(state, gold_action)
        cost_history.append(step_costs)
    return state, cost_history
@pytest.fixture
def vocab():
    """Pytest fixture returning a newly constructed, empty ``Vocab``."""
    fresh_vocab = Vocab()
    return fresh_vocab
@pytest.fixture
def arc_eager(vocab):
    """Pytest fixture returning an ``ArcEager`` system with two extra actions.

    The system is built from ``vocab.strings`` and the default actions, then
    action 2 is registered with label "left" and action 3 with label "right"
    (the tests in this file comment these action ids as Left and Right).
    """
    transition_system = ArcEager(vocab.strings, ArcEager.get_actions())
    for action_id, label in ((2, "left"), (3, "right")):
        transition_system.add_action(action_id, label)
    return transition_system
def test_oracle_four_words(arc_eager, vocab):
    """Check which moves cost less than 1 at each step of a fixed sequence.

    A four-word doc is annotated as two two-word trees. The hand-written
    action sequence is replayed through ``get_sequence_costs`` and, for every
    step, the list of moves with cost below 1 must equal the expected list.
    """
    words = ["a", "b", "c", "d"]
    heads = [1, 1, 3, 3]
    deps = ["left", "ROOT", "left", "ROOT"]
    for label in deps:
        arc_eager.add_action(2, label)  # Left
        arc_eager.add_action(3, label)  # Right
    actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
    final_state, cost_history = get_sequence_costs(
        arc_eager, words, heads, deps, actions
    )
    expected_gold = [
        ["S"],
        ["B-ROOT", "L-left"],
        ["B-ROOT"],
        ["S"],
        ["D"],
        ["S"],
        ["L-left"],
        ["S"],
        ["D"],
    ]
    assert final_state.is_final()
    for step, costs in enumerate(cost_history):
        # The gold moves are the ones whose cost is zero (i.e. below 1).
        zero_cost_moves = [name for name, cost in costs.items() if cost < 1]
        wanted = expected_gold[step]
        assert zero_cost_moves == wanted, (step, zero_cost_moves, wanted)
# Annotated example sentence consumed by test_get_oracle_actions.
# Each row is unpacked there as (id, word, tag, head, dep, ent):
#   id   - integer position of the row (0..48)
#   word - token text
#   tag  - tag string (e.g. "NNP", "VBZ")
#   head - integer head value paired with `dep`; row 44 ("talking", "ROOT")
#          lists itself as its head
#   dep  - dependency label
#   ent  - entity tag string (e.g. "O", "B-PERSON", "L-PERSON", "U-GPE")
annot_tuples = [
(0, "When", "WRB", 11, "advmod", "O"),
(1, "Walter", "NNP", 2, "compound", "B-PERSON"),
(2, "Rodgers", "NNP", 11, "nsubj", "L-PERSON"),
(3, ",", ",", 2, "punct", "O"),
(4, "our", "PRP$", 6, "poss", "O"),
(5, "embedded", "VBN", 6, "amod", "O"),
(6, "reporter", "NN", 2, "appos", "O"),
(7, "with", "IN", 6, "prep", "O"),
(8, "the", "DT", 10, "det", "B-ORG"),
(9, "3rd", "NNP", 10, "compound", "I-ORG"),
(10, "Cavalry", "NNP", 7, "pobj", "L-ORG"),
(11, "says", "VBZ", 44, "advcl", "O"),
(12, "three", "CD", 13, "nummod", "U-CARDINAL"),
(13, "battalions", "NNS", 16, "nsubj", "O"),
(14, "of", "IN", 13, "prep", "O"),
(15, "troops", "NNS", 14, "pobj", "O"),
(16, "are", "VBP", 11, "ccomp", "O"),
(17, "on", "IN", 16, "prep", "O"),
(18, "the", "DT", 19, "det", "O"),
(19, "ground", "NN", 17, "pobj", "O"),
(20, ",", ",", 17, "punct", "O"),
(21, "inside", "IN", 17, "prep", "O"),
(22, "Baghdad", "NNP", 21, "pobj", "U-GPE"),
(23, "itself", "PRP", 22, "appos", "O"),
(24, ",", ",", 16, "punct", "O"),
(25, "have", "VBP", 26, "aux", "O"),
(26, "taken", "VBN", 16, "dep", "O"),
(27, "up", "RP", 26, "prt", "O"),
(28, "positions", "NNS", 26, "dobj", "O"),
(29, "they", "PRP", 31, "nsubj", "O"),
(30, "'re", "VBP", 31, "aux", "O"),
(31, "going", "VBG", 26, "parataxis", "O"),
(32, "to", "TO", 33, "aux", "O"),
(33, "spend", "VB", 31, "xcomp", "O"),
(34, "the", "DT", 35, "det", "B-TIME"),
(35, "night", "NN", 33, "dobj", "L-TIME"),
(36, "there", "RB", 33, "advmod", "O"),
(37, "presumably", "RB", 33, "advmod", "O"),
(38, ",", ",", 44, "punct", "O"),
(39, "how", "WRB", 40, "advmod", "O"),
(40, "many", "JJ", 41, "amod", "O"),
(41, "soldiers", "NNS", 44, "pobj", "O"),
(42, "are", "VBP", 44, "aux", "O"),
(43, "we", "PRP", 44, "nsubj", "O"),
(44, "talking", "VBG", 44, "ROOT", "O"),
(45, "about", "IN", 44, "prep", "O"),
(46, "right", "RB", 47, "advmod", "O"),
(47, "now", "RB", 44, "advmod", "O"),
(48, "?", ".", 44, "punct", "O"),
]
def test_get_oracle_actions():
    """Smoke-test that an oracle sequence can be requested for ``annot_tuples``.

    Builds a ``DependencyParser`` with the default parser model, registers the
    actions needed for the (projectivized) gold tree, and calls
    ``get_oracle_sequence``. There is no assertion on the result: the test
    passes as long as no exception is raised.
    """
    # Only these four columns are used; the id and entity-tag columns of
    # ``annot_tuples`` are not needed here. Lists (not tuples) are kept in
    # case ``projectivize`` or ``Example.from_dict`` expect mutable sequences.
    words = [row[1] for row in annot_tuples]
    tags = [row[2] for row in annot_tuples]
    heads = [row[3] for row in annot_tuples]
    deps = [row[4] for row in annot_tuples]

    doc = Doc(Vocab(), words=words)
    cfg = {"model": DEFAULT_PARSER_MODEL}
    model = registry.resolve(cfg, validate=True)["model"]
    parser = DependencyParser(doc.vocab, model)
    parser.moves.add_action(0, "")
    parser.moves.add_action(1, "")
    # NOTE(review): action 1 is registered twice in the original test;
    # presumably the second call is a no-op - kept as-is, confirm before removing.
    parser.moves.add_action(1, "")
    parser.moves.add_action(4, "ROOT")

    heads, deps = projectivize(heads, deps)
    for i, (head, dep) in enumerate(zip(heads, deps)):
        if head > i:
            # Head lies to the right of the token: register action 2.
            parser.moves.add_action(2, dep)
        elif head < i:
            # Head lies to the left of the token: register action 3.
            parser.moves.add_action(3, dep)
        # head == i (a self-headed token) registers nothing.

    example = Example.from_dict(
        doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
    )
    parser.moves.get_oracle_sequence(example)
def test_oracle_dev_sentence(vocab, arc_eager):
    """Check the full oracle transition sequence for one annotated sentence.

    The sentence is written as one ``word dep head`` triple per line, where
    ``head`` is the text of the head word. Head texts are converted to
    indices with ``list.index``, so every head word must be unique in the
    sentence. The oracle's action names must equal ``expected_transitions``.
    """
    words_deps_heads = """
Rolls-Royce nn Inc.
Motor nn Inc.
Cars nn Inc.
Inc. nsubj said
said ROOT said
it nsubj expects
expects ccomp said
its poss sales
U.S. nn sales
sales nsubj steady
to aux steady
remain cop steady
steady xcomp expects
at prep steady
about quantmod 1,200
1,200 num cars
cars pobj at
in prep steady
1990 pobj in
. punct said
"""
    expected_transitions = [
        "S",  # Shift 'Rolls-Royce'
        "S",  # Shift 'Motor'
        "S",  # Shift 'Cars'
        "L-nn",  # Attach 'Cars' to 'Inc.'
        "L-nn",  # Attach 'Motor' to 'Inc.'
        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.'
        "S",  # Shift 'Inc.'
        "L-nsubj",  # Attach 'Inc.' to 'said'
        "S",  # Shift 'said'
        "S",  # Shift 'it'
        "L-nsubj",  # Attach 'it' to 'expects'
        "R-ccomp",  # Attach 'expects' to 'said'
        "S",  # Shift 'its'
        "S",  # Shift 'U.S.'
        "L-nn",  # Attach 'U.S.' to 'sales'
        "L-poss",  # Attach 'its' to 'sales'
        "S",  # Shift 'sales'
        "S",  # Shift 'to'
        "S",  # Shift 'remain'
        "L-cop",  # Attach 'remain' to 'steady'
        "L-aux",  # Attach 'to' to 'steady'
        "L-nsubj",  # Attach 'sales' to 'steady'
        "R-xcomp",  # Attach 'steady' to 'expects'
        "R-prep",  # Attach 'at' to 'steady'
        "S",  # Shift 'about'
        "L-quantmod",  # Attach 'about' to '1,200'
        "S",  # Shift '1,200'
        "L-num",  # Attach '1,200' to 'cars'
        "R-pobj",  # Attach 'cars' to 'at'
        "D",  # Reduce 'cars'
        "D",  # Reduce 'at'
        "R-prep",  # Attach 'in' to 'steady'
        "R-pobj",  # Attach '1990' to 'in'
        "D",  # Reduce '1990'
        "D",  # Reduce 'in'
        "D",  # Reduce 'steady'
        "D",  # Reduce 'expects'
        "R-punct",  # Attach '.' to 'said'
        "D",  # Reduce '.'
        "D",  # Reduce 'said'
    ]
    # One [word, dep, head] row per non-blank line of the annotation text.
    stripped_lines = (raw.strip() for raw in words_deps_heads.strip().split("\n"))
    rows = [text.split() for text in stripped_lines if text]
    gold_words = [word for word, _, _ in rows]
    gold_deps = [dep for _, dep, _ in rows]
    # Heads are given as word text; map each to the index of that word.
    gold_heads = [gold_words.index(head) for _, _, head in rows]
    for label in gold_deps:
        arc_eager.add_action(2, label)  # Left
        arc_eager.add_action(3, label)  # Right
    doc = Doc(Vocab(), words=gold_words)
    example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(action) for action in oracle_ids]
    assert oracle_names == expected_transitions
def test_oracle_bad_tokenization(vocab, arc_eager):
    """Check an oracle sequence exists when predicted tokens differ from gold.

    The reference doc has the single token "[catalase]", while the predicted
    doc splits it into "[", "catalase", "]". The test only requires the oracle
    to return a non-empty action sequence for this pair.
    """
    words_deps_heads = """
[catalase] dep is
: punct is
that nsubj is
is root is
bad comp is
"""
    # One [word, dep, head] row per non-blank line of the annotation text.
    stripped_lines = (raw.strip() for raw in words_deps_heads.strip().split("\n"))
    rows = [text.split() for text in stripped_lines if text]
    gold_words = [word for word, _, _ in rows]
    gold_deps = [dep for _, dep, _ in rows]
    # Heads are given as word text; map each to the index of that word.
    gold_heads = [gold_words.index(head) for _, _, head in rows]
    for label in gold_deps:
        arc_eager.add_action(2, label)  # Left
        arc_eager.add_action(3, label)  # Right
    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
    predicted_words = ["[", "catalase", "]", ":", "that", "is", "bad"]
    predicted = Doc(reference.vocab, words=predicted_words)
    example = Example(predicted=predicted, reference=reference)
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(action) for action in oracle_ids]
    assert oracle_names