fix parser tests to work with example (most still failing)

This commit is contained in:
svlandeg 2020-06-18 11:19:22 +02:00
parent 9f43ba839a
commit cd790aaa2a
8 changed files with 50 additions and 39 deletions

View File

@ -497,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
def has_gold(self, gold, start=0, end=None):
raise NotImplementedError
def preprocess_gold(self, gold):
def preprocess_gold(self, example):
raise NotImplementedError
cdef Transition lookup_transition(self, object name_or_id) except *:

View File

@ -94,7 +94,7 @@ cdef class BiluoPushDown(TransitionSystem):
def has_gold(self, gold, start=0, end=None):
raise NotImplementedError
def preprocess_gold(self, gold):
def preprocess_gold(self, example):
raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *:

View File

@ -97,7 +97,7 @@ cdef class TransitionSystem:
def finalize_doc(self, doc):
pass
def preprocess_gold(self, gold):
def preprocess_gold(self, example):
raise NotImplementedError
def is_gold_parse(self, StateClass state, example):

View File

@ -1,6 +1,7 @@
import pytest
from spacy.vocab import Vocab
from spacy.gold import Example
from spacy.pipeline.defaults import default_parser
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
@ -11,9 +12,9 @@ from spacy.syntax.arc_eager import ArcEager
def get_sequence_costs(M, words, heads, deps, transitions):
doc = Doc(Vocab(), words=words)
gold = GoldParse(doc, heads=heads, deps=deps)
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
state = StateClass(doc)
M.preprocess_gold(gold)
M.preprocess_gold(example)
cost_history = []
for gold_action in transitions:
state_costs = {}
@ -149,6 +150,6 @@ def test_get_oracle_actions():
elif head < i:
parser.moves.add_action(3, dep)
heads, deps = projectivize(heads, deps)
gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
parser.moves.preprocess_gold(gold)
parser.moves.get_oracle_sequence(doc, gold)
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
parser.moves.preprocess_gold(example)
parser.moves.get_oracle_sequence(example)

View File

@ -1,4 +1,6 @@
import pytest
from spacy.attrs import ENT_IOB
from spacy import util
from spacy.lang.en import English
from spacy.pipeline.defaults import default_ner
@ -8,7 +10,7 @@ from spacy.syntax.ner import BiluoPushDown
from spacy.tokens import Doc
from ..util import make_tempdir
from ...gold import Example
TRAIN_DATA = [
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
@ -48,41 +50,45 @@ def tsys(vocab, entity_types):
def test_get_oracle_moves(tsys, doc, entity_annots):
gold = GoldParse(doc, entities=entity_annots)
tsys.preprocess_gold(gold)
act_classes = tsys.get_oracle_sequence(doc, gold)
example = Example.from_dict(doc, {"entities": entity_annots})
tsys.preprocess_gold(example)
act_classes = tsys.get_oracle_sequence(example)
names = [tsys.get_class_name(act) for act in act_classes]
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
gold = GoldParse(doc, entities=entity_annots)
for i, tag in enumerate(gold.ner):
example = Example.from_dict(doc, {"entities": entity_annots})
ex_dict = example.to_dict()
for i, tag in enumerate(ex_dict["doc_annotation"]["entities"]):
if tag == "L-!GPE":
gold.ner[i] = "-"
tsys.preprocess_gold(gold)
act_classes = tsys.get_oracle_sequence(doc, gold)
ex_dict["doc_annotation"]["entities"][i] = "-"
example = Example.from_dict(doc, ex_dict)
tsys.preprocess_gold(example)
act_classes = tsys.get_oracle_sequence(example)
names = [tsys.get_class_name(act) for act in act_classes]
assert names
def test_get_oracle_moves_negative_entities2(tsys, vocab):
doc = Doc(vocab, words=["A", "B", "C", "D"])
gold = GoldParse(doc, entities=[])
gold.ner = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
tsys.preprocess_gold(gold)
act_classes = tsys.get_oracle_sequence(doc, gold)
entity_annots = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
example = Example.from_dict(doc, {"entities": entity_annots})
tsys.preprocess_gold(example)
act_classes = tsys.get_oracle_sequence(example)
names = [tsys.get_class_name(act) for act in act_classes]
assert names
def test_get_oracle_moves_negative_O(tsys, vocab):
doc = Doc(vocab, words=["A", "B", "C", "D"])
gold = GoldParse(doc, entities=[])
gold.ner = ["O", "!O", "O", "!O"]
tsys.preprocess_gold(gold)
act_classes = tsys.get_oracle_sequence(doc, gold)
entity_annots = ["O", "!O", "O", "!O"]
example = Example.from_dict(doc, {"entities": []})
tsys.preprocess_gold(example)
act_classes = tsys.get_oracle_sequence(example)
names = [tsys.get_class_name(act) for act in act_classes]
assert names
@ -92,7 +98,7 @@ def test_oracle_moves_missing_B(en_vocab):
biluo_tags = [None, None, "L-PRODUCT"]
doc = Doc(en_vocab, words=words)
gold = GoldParse(doc, words=words, entities=biluo_tags)
example = Example.from_dict(doc, {"words": words, "entities": biluo_tags})
moves = BiluoPushDown(en_vocab.strings)
move_types = ("M", "B", "I", "L", "U", "O")
@ -107,8 +113,8 @@ def test_oracle_moves_missing_B(en_vocab):
moves.add_action(move_types.index("I"), label)
moves.add_action(move_types.index("L"), label)
moves.add_action(move_types.index("U"), label)
moves.preprocess_gold(gold)
moves.get_oracle_sequence(doc, gold)
moves.preprocess_gold(example)
moves.get_oracle_sequence(example)
def test_oracle_moves_whitespace(en_vocab):
@ -116,7 +122,7 @@ def test_oracle_moves_whitespace(en_vocab):
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
doc = Doc(en_vocab, words=words)
gold = GoldParse(doc, words=words, entities=biluo_tags)
example = Example.from_dict(doc, {"entities": biluo_tags})
moves = BiluoPushDown(en_vocab.strings)
move_types = ("M", "B", "I", "L", "U", "O")
@ -128,8 +134,8 @@ def test_oracle_moves_whitespace(en_vocab):
else:
action, label = tag.split("-")
moves.add_action(move_types.index(action), label)
moves.preprocess_gold(gold)
moves.get_oracle_sequence(doc, gold)
moves.preprocess_gold(example)
moves.get_oracle_sequence(example)
def test_accept_blocked_token():

View File

@ -1,4 +1,6 @@
import pytest
from spacy.gold import Example
from spacy.pipeline.defaults import default_parser, default_tok2vec
from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager
@ -71,7 +73,8 @@ def test_update_doc(parser, model, doc, gold):
weights -= 0.001 * gradient
return weights, gradient
parser.update((doc, gold), sgd=optimize)
example = Example.from_dict(doc, gold)
parser.update([example], sgd=optimize)
@pytest.mark.xfail

View File

@ -42,11 +42,6 @@ def tokvecs(docs, vector_size):
return output
@pytest.fixture
def golds(docs):
return [GoldParse(doc) for doc in docs]
@pytest.fixture
def batch_size(docs):
return len(docs)
@ -77,19 +72,24 @@ def scores(moves, batch_size, beam_width):
]
# All tests below are skipped after removing Beam stuff during the Example/GoldParse refactor
@pytest.mark.skip
def test_create_beam(beam):
pass
@pytest.mark.skip
def test_beam_advance(beam, scores):
beam.advance(scores)
@pytest.mark.skip
def test_beam_advance_too_few_scores(beam, scores):
with pytest.raises(IndexError):
beam.advance(scores[:-1])
@pytest.mark.skip
def test_beam_parse():
nlp = Language()
config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}

View File

@ -3,6 +3,7 @@ from thinc.api import Adam
from spacy.attrs import NORM
from spacy.vocab import Vocab
from spacy.gold import Example
from spacy.pipeline.defaults import default_parser
from spacy.tokens import Doc
from spacy.pipeline import DependencyParser
@ -27,8 +28,8 @@ def parser(vocab):
for i in range(10):
losses = {}
doc = Doc(vocab, words=["a", "b", "c", "d"])
gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update((doc, gold), sgd=sgd, losses=losses)
example = Example.from_dict(doc, {"heads": [1, 1, 3, 3], "deps": ["left", "ROOT", "left", "ROOT"]})
parser.update([example], sgd=sgd, losses=losses)
return parser