spaCy/spacy/tests/parser/test_arc_eager_oracle.py
Matthew Honnibal 8656a08777
Add beam_parser and beam_ner components for v3 (#6369)
* Get basic beam tests working

* Get basic beam tests working

* Compile _beam_utils

* Remove prints

* Test beam density

* Beam parser seems to train

* Draft beam NER

* Upd beam

* Add hypothesis as dev dependency

* Implement missing is-gold-parse method

* Implement early update

* Fix state hashing

* Fix test

* Fix test

* Default to non-beam in parser constructor

* Improve oracle for beam

* Start refactoring beam

* Update test

* Refactor beam

* Update nn

* Refactor beam and weight by cost

* Update ner beam settings

* Update test

* Add __init__.pxd

* Upd test

* Fix test

* Upd test

* Fix test

* Remove ring buffer history from StateC

* WIP change arc-eager transitions

* Add state tests

* Support ternary sent start values

* Fix arc eager

* Fix NER

* Pass oracle cut size for beam

* Fix ner test

* Fix beam

* Improve StateC.clone

* Improve StateClass.borrow

* Work directly with StateC, not StateClass

* Remove print statements

* Fix state copy

* Improve state class

* Refactor parser oracles

* Fix arc eager oracle

* Fix arc eager oracle

* Use a vector to implement the stack

* Refactor state data structure

* Fix alignment of sent start

* Add get_aligned_sent_starts method

* Add test for ae oracle when bad sentence starts

* Fix sentence segment handling

* Avoid Reduce that inserts illegal sentence

* Update preset SBD test

* Fix test

* Remove prints

* Fix sent starts in Example

* Improve python API of StateClass

* Tweak comments and debug output of arc eager

* Upd test

* Fix state test

* Fix state test
2020-12-13 09:08:32 +08:00

276 lines
9.3 KiB
Python

import pytest
from spacy.vocab import Vocab
from spacy import registry
from spacy.training import Example
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline._parser_internals.stateclass import StateClass
def get_sequence_costs(M, words, heads, deps, transitions):
    """Apply `transitions` to a fresh state and record every move's cost.

    M: Transition system providing `init_gold_batch`, `n_moves`,
        `class_name`, `get_cost` and `transition`.
    words, heads, deps: Annotations used to build the single gold example.
    transitions: Action names applied to the state, in order.

    RETURNS: Tuple of the state after the last transition and a list with
        one dict per step, mapping each move name to the cost measured
        before that step's transition was applied.
    """
    gold_example = Example.from_dict(
        Doc(Vocab(), words=words), {"heads": heads, "deps": deps}
    )
    states, golds, _ = M.init_gold_batch([gold_example])
    state, gold = states[0], golds[0]
    cost_history = []
    for action_name in transitions:
        # Refresh the gold object against the current state before costing.
        gold.update(state)
        step_costs = {
            M.class_name(move_id): M.get_cost(state, gold, move_id)
            for move_id in range(M.n_moves)
        }
        M.transition(state, action_name)
        cost_history.append(step_costs)
    return state, cost_history
@pytest.fixture
def vocab():
    """Provide a default-constructed `Vocab` to tests that request it."""
    default_vocab = Vocab()
    return default_vocab
@pytest.fixture
def arc_eager(vocab):
    """Build an `ArcEager` transition system with "left"/"right" labels.

    Starts from `ArcEager.get_actions()` and then registers one extra
    action per (move id, label) pair listed below.
    """
    transition_system = ArcEager(vocab.strings, ArcEager.get_actions())
    for move_id, label in ((2, "left"), (3, "right")):
        transition_system.add_action(move_id, label)
    return transition_system
def test_oracle_four_words(arc_eager, vocab):
    """Check which moves cost less than 1 at each step of a four-word parse."""
    words = ["a", "b", "c", "d"]
    heads = [1, 1, 3, 3]
    deps = ["left", "ROOT", "left", "ROOT"]
    for label in deps:
        for move_id in (2, 3):  # 2 = Left, 3 = Right
            arc_eager.add_action(move_id, label)
    actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
    # One entry per action in `actions`: the moves expected to cost less
    # than 1 in the state that action is applied to.
    expected_gold = [
        ["S"],
        ["B-ROOT", "L-left"],
        ["B-ROOT"],
        ["S"],
        ["D"],
        ["S"],
        ["L-left"],
        ["S"],
        ["D"],
    ]
    assert state.is_final()
    for step, state_costs in enumerate(cost_history):
        cheap_moves = [name for name in state_costs if state_costs[name] < 1]
        wanted = expected_gold[step]
        assert cheap_moves == wanted, (step, cheap_moves, wanted)
# Gold annotation for one long sentence, consumed by test_get_oracle_actions.
# Each row is (token index, word, tag, head index, dependency label, entity tag).
# Head indices are absolute token positions; the ROOT row (44) lists itself
# as its own head.
annot_tuples = [
(0, "When", "WRB", 11, "advmod", "O"),
(1, "Walter", "NNP", 2, "compound", "B-PERSON"),
(2, "Rodgers", "NNP", 11, "nsubj", "L-PERSON"),
(3, ",", ",", 2, "punct", "O"),
(4, "our", "PRP$", 6, "poss", "O"),
(5, "embedded", "VBN", 6, "amod", "O"),
(6, "reporter", "NN", 2, "appos", "O"),
(7, "with", "IN", 6, "prep", "O"),
(8, "the", "DT", 10, "det", "B-ORG"),
(9, "3rd", "NNP", 10, "compound", "I-ORG"),
(10, "Cavalry", "NNP", 7, "pobj", "L-ORG"),
(11, "says", "VBZ", 44, "advcl", "O"),
(12, "three", "CD", 13, "nummod", "U-CARDINAL"),
(13, "battalions", "NNS", 16, "nsubj", "O"),
(14, "of", "IN", 13, "prep", "O"),
(15, "troops", "NNS", 14, "pobj", "O"),
(16, "are", "VBP", 11, "ccomp", "O"),
(17, "on", "IN", 16, "prep", "O"),
(18, "the", "DT", 19, "det", "O"),
(19, "ground", "NN", 17, "pobj", "O"),
(20, ",", ",", 17, "punct", "O"),
(21, "inside", "IN", 17, "prep", "O"),
(22, "Baghdad", "NNP", 21, "pobj", "U-GPE"),
(23, "itself", "PRP", 22, "appos", "O"),
(24, ",", ",", 16, "punct", "O"),
(25, "have", "VBP", 26, "aux", "O"),
(26, "taken", "VBN", 16, "dep", "O"),
(27, "up", "RP", 26, "prt", "O"),
(28, "positions", "NNS", 26, "dobj", "O"),
(29, "they", "PRP", 31, "nsubj", "O"),
(30, "'re", "VBP", 31, "aux", "O"),
(31, "going", "VBG", 26, "parataxis", "O"),
(32, "to", "TO", 33, "aux", "O"),
(33, "spend", "VB", 31, "xcomp", "O"),
(34, "the", "DT", 35, "det", "B-TIME"),
(35, "night", "NN", 33, "dobj", "L-TIME"),
(36, "there", "RB", 33, "advmod", "O"),
(37, "presumably", "RB", 33, "advmod", "O"),
(38, ",", ",", 44, "punct", "O"),
(39, "how", "WRB", 40, "advmod", "O"),
(40, "many", "JJ", 41, "amod", "O"),
(41, "soldiers", "NNS", 44, "pobj", "O"),
(42, "are", "VBP", 44, "aux", "O"),
(43, "we", "PRP", 44, "nsubj", "O"),
(44, "talking", "VBG", 44, "ROOT", "O"),
(45, "about", "IN", 44, "prep", "O"),
(46, "right", "RB", 47, "advmod", "O"),
(47, "now", "RB", 44, "advmod", "O"),
(48, "?", ".", 44, "punct", "O"),
]
def test_get_oracle_actions():
    """Smoke test: an oracle sequence can be requested for `annot_tuples`.

    Nothing is asserted about the returned sequence; the test only
    requires that the call completes without raising.
    """
    # Transpose the annotation rows into one list per column.
    ids, words, tags, heads, deps, ents = (
        list(column) for column in zip(*annot_tuples)
    )
    doc = Doc(Vocab(), words=[row[1] for row in annot_tuples])
    parser_settings = {
        "learn_tokens": False,
        "min_action_freq": 0,
        "update_with_oracle_cut_size": 100,
    }
    model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
    parser = DependencyParser(doc.vocab, model, **parser_settings)
    # Register the unlabelled moves and the ROOT move, in the original order.
    for move_id, label in ((0, ""), (1, ""), (1, ""), (4, "ROOT")):
        parser.moves.add_action(move_id, label)
    heads, deps = projectivize(heads, deps)
    for index, (head, label) in enumerate(zip(heads, deps)):
        if head == index:
            # Self-headed token: no arc action is registered for it.
            continue
        # Head to the right gets move 2, head to the left gets move 3.
        parser.moves.add_action(2 if head > index else 3, label)
    gold_annots = {"words": words, "tags": tags, "heads": heads, "deps": deps}
    example = Example.from_dict(doc, gold_annots)
    parser.moves.get_oracle_sequence(example)
def test_oracle_dev_sentence(vocab, arc_eager):
    """Compare the oracle's full action sequence for one sentence to a fixed list."""
    # One token per row: "<word> <dep label> <head word>". Rows are
    # whitespace-split when parsed, so surrounding whitespace is irrelevant.
    words_deps_heads = """
Rolls-Royce nn Inc.
Motor nn Inc.
Cars nn Inc.
Inc. nsubj said
said ROOT said
it nsubj expects
expects ccomp said
its poss sales
U.S. nn sales
sales nsubj steady
to aux steady
remain cop steady
steady xcomp expects
at prep steady
about quantmod 1,200
1,200 num cars
cars pobj at
in prep steady
1990 pobj in
. punct said
"""
    expected_transitions = [
        "S",  # Shift 'Rolls-Royce'
        "S",  # Shift 'Motor'
        "S",  # Shift 'Cars'
        "L-nn",  # Attach 'Cars' to 'Inc.'
        "L-nn",  # Attach 'Motor' to 'Inc.'
        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.'
        "S",  # Shift 'Inc.'
        "L-nsubj",  # Attach 'Inc.' to 'said'
        "S",  # Shift 'said'
        "S",  # Shift 'it'
        "L-nsubj",  # Attach 'it' to 'expects'
        "R-ccomp",  # Attach 'expects' to 'said'
        "S",  # Shift 'its'
        "S",  # Shift 'U.S.'
        "L-nn",  # Attach 'U.S.' to 'sales'
        "L-poss",  # Attach 'its' to 'sales'
        "S",  # Shift 'sales'
        "S",  # Shift 'to'
        "S",  # Shift 'remain'
        "L-cop",  # Attach 'remain' to 'steady'
        "L-aux",  # Attach 'to' to 'steady'
        "L-nsubj",  # Attach 'sales' to 'steady'
        "R-xcomp",  # Attach 'steady' to 'expects'
        "R-prep",  # Attach 'at' to 'steady'
        "S",  # Shift 'about'
        "L-quantmod",  # Attach 'about' to '1,200'
        "S",  # Shift '1,200'
        "L-num",  # Attach '1,200' to 'cars'
        "R-pobj",  # Attach 'cars' to 'at'
        "D",  # Reduce 'cars'
        "D",  # Reduce 'at'
        "R-prep",  # Attach 'in' to 'steady'
        "R-pobj",  # Attach '1990' to 'in'
        "D",  # Reduce '1990'
        "D",  # Reduce 'in'
        "D",  # Reduce 'steady'
        "D",  # Reduce 'expects'
        "R-punct",  # Attach '.' to 'said'
        "D",  # Reduce '.'
        "D",  # Reduce 'said'
    ]
    gold_words, gold_deps, head_words = [], [], []
    for raw_row in words_deps_heads.strip().split("\n"):
        fields = raw_row.split()
        if not fields:
            continue
        word, dep, head_word = fields
        gold_words.append(word)
        gold_deps.append(dep)
        head_words.append(head_word)
    # Turn each head word into the position of its first occurrence.
    gold_heads = [gold_words.index(head_word) for head_word in head_words]
    for label in gold_deps:
        for move_id in (2, 3):  # 2 = Left, 3 = Right
            arc_eager.add_action(move_id, label)
    example = Example.from_dict(
        Doc(Vocab(), words=gold_words),
        {"heads": gold_heads, "deps": gold_deps},
    )
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(move_id) for move_id in oracle_ids]
    assert oracle_names == expected_transitions
def test_oracle_bad_tokenization(vocab, arc_eager):
    """Check the oracle yields actions when tokenizations differ.

    The reference doc holds "[catalase]" as one token, while the predicted
    doc splits it into "[", "catalase" and "]". The test only requires a
    non-empty oracle sequence.
    """
    # One token per row: "<word> <dep label> <head word>".
    words_deps_heads = """
[catalase] dep is
: punct is
that nsubj is
is root is
bad comp is
"""
    gold_words, gold_deps, head_words = [], [], []
    for raw_row in words_deps_heads.strip().split("\n"):
        fields = raw_row.split()
        if not fields:
            continue
        word, dep, head_word = fields
        gold_words.append(word)
        gold_deps.append(dep)
        head_words.append(head_word)
    # Turn each head word into the position of its first occurrence.
    gold_heads = [gold_words.index(head_word) for head_word in head_words]
    for label in gold_deps:
        for move_id in (2, 3):  # 2 = Left, 3 = Right
            arc_eager.add_action(move_id, label)
    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
    predicted_words = ["[", "catalase", "]", ":", "that", "is", "bad"]
    predicted = Doc(reference.vocab, words=predicted_words)
    example = Example(predicted=predicted, reference=reference)
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(move_id) for move_id in oracle_ids]
    assert oracle_names