Add xfail test for #3433. Improve test for add label.

This commit is contained in:
Matthew Honnibal 2019-03-23 12:35:29 +01:00
parent 6b6e9b638e
commit 444a3abfe5

View File

@ -8,7 +8,8 @@ from spacy.attrs import NORM
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.pipeline import DependencyParser from spacy.pipeline import DependencyParser, EntityRecognizer
from spacy.util import fix_random_seed
@pytest.fixture @pytest.fixture
@ -19,18 +20,6 @@ def vocab():
@pytest.fixture @pytest.fixture
def parser(vocab): def parser(vocab):
parser = DependencyParser(vocab) parser = DependencyParser(vocab)
parser.cfg["token_vector_width"] = 8
parser.cfg["hidden_width"] = 30
parser.cfg["hist_size"] = 0
parser.add_label("left")
parser.begin_training([], **parser.cfg)
sgd = Adam(NumpyOps(), 0.001)
for i in range(10):
losses = {}
doc = Doc(vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update([doc], [gold], sgd=sgd, losses=losses)
return parser return parser
@ -38,10 +27,22 @@ def test_init_parser(parser):
pass pass
# TODO: This is flakey, because it depends on what the parser first learns. def _train_parser(parser):
# TODO: This now seems to be implicated in segfaults. Not sure what's up! fix_random_seed(1)
@pytest.mark.skip parser.add_label("left")
parser.begin_training([], **parser.cfg)
sgd = Adam(NumpyOps(), 0.001)
for i in range(10):
losses = {}
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update([doc], [gold], sgd=sgd, losses=losses)
return parser
def test_add_label(parser): def test_add_label(parser):
parser = _train_parser(parser)
doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc = parser(doc) doc = parser(doc)
assert doc[0].head.i == 1 assert doc[0].head.i == 1
@ -69,3 +70,16 @@ def test_add_label(parser):
doc = parser(doc) doc = parser(doc)
assert doc[0].dep_ == "right" assert doc[0].dep_ == "right"
assert doc[2].dep_ == "left" assert doc[2].dep_ == "left"
@pytest.mark.xfail
def test_add_label_deserializes_correctly():
ner1 = EntityRecognizer(Vocab())
ner1.add_label("C")
ner1.add_label("B")
ner1.add_label("A")
ner1.begin_training([])
ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)