# Added to spacy/tests/parser/test_arc_eager_oracle.py
def test_oracle_dev_sentence(vocab, arc_eager):
    """Check the oracle transition sequence for one annotated sentence.

    The sentence is written as one ``word dep head`` row per token, where
    ``head`` is the word form of the token's head (the root row names
    itself). The rows are turned into gold heads/deps, every dependency
    label is registered with ``arc_eager``, and the action names returned by
    the oracle are compared with a hand-written expected sequence.
    """
    # NOTE(review): the ``vocab`` fixture is requested but not used; the Doc
    # below is built on a fresh ``Vocab()`` -- confirm this is intentional.
    annotated_sentence = """
        Rolls-Royce nn Inc.
        Motor nn Inc.
        Cars nn Inc.
        Inc. nsubj said
        said ROOT said
        it nsubj expects
        expects ccomp said
        its poss sales
        U.S. nn sales
        sales nsubj steady
        to aux steady
        remain cop steady
        steady xcomp expects
        at prep steady
        about quantmod 1,200
        1,200 num cars
        cars pobj at
        in prep steady
        1990 pobj in
        . punct said
    """
    expected_transitions = [
        "S",  # Shift 'Motor'
        "S",  # Shift 'Cars'
        "L-nn",  # Attach 'Cars' to 'Inc.'
        "L-nn",  # Attach 'Motor' to 'Inc.'
        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.', force shift
        "L-nsubj",  # Attach 'Inc.' to 'said'
        "S",  # Shift 'it'
        "L-nsubj",  # Attach 'it' to 'expects'
        "R-ccomp",  # Attach 'expects' to 'said'
        "S",  # Shift 'its'
        "S",  # Shift 'U.S.'
        "L-nn",  # Attach 'U.S.' to 'sales'
        "L-poss",  # Attach 'its' to 'sales'
        "S",  # Shift 'sales'
        "S",  # Shift 'to'
        "S",  # Shift 'remain'
        "L-cop",  # Attach 'remain' to 'steady'
        "L-aux",  # Attach 'to' to 'steady'
        "L-nsubj",  # Attach 'sales' to 'steady'
        "R-xcomp",  # Attach 'steady' to 'expects'
        "R-prep",  # Attach 'at' to 'steady'
        "S",  # Shift 'about'
        "L-quantmod",  # Attach 'about' to '1,200'
        "S",  # Shift '1,200'
        "L-num",  # Attach '1,200' to 'cars'
        "R-pobj",  # Attach 'cars' to 'at'
        "D",  # Reduce 'cars'
        "D",  # Reduce 'at'
        "R-prep",  # Attach 'in' to 'steady'
        "R-pobj",  # Attach '1990' to 'in'
        "D",  # Reduce '1990'
        "D",  # Reduce 'in'
        "D",  # Reduce 'steady'
        "D",  # Reduce 'expects'
        "R-punct",  # Attach '.' to 'said'
    ]

    # One [word, dep, head] triple per non-blank row of the annotation.
    rows = [
        raw_row.split()
        for raw_row in annotated_sentence.strip().split("\n")
        if raw_row.strip()
    ]
    tokens = [word for word, _, _ in rows]
    labels = [dep for _, dep, _ in rows]
    # Heads are written as word forms; resolve each to the position of the
    # first token with that form.
    head_indices = [tokens.index(head) for _, _, head in rows]

    # Register every label as both a Left (2) and a Right (3) action, in the
    # order the labels appear in the sentence.
    for label in labels:
        for move in (2, 3):
            arc_eager.add_action(move, label)

    doc = Doc(Vocab(), words=tokens)
    gold_annotations = {"heads": head_indices, "deps": labels}
    example = Example.from_dict(doc, gold_annotations)

    oracle_action_ids = arc_eager.get_oracle_sequence(example)
    oracle_action_names = [
        arc_eager.get_class_name(action_id) for action_id in oracle_action_ids
    ]
    assert oracle_action_names == expected_transitions