1
1
mirror of https://github.com/explosion/spaCy.git synced 2025-01-17 04:54:14 +03:00
spaCy/spacy/tests/parser/test_add_label.py
2017-10-28 13:16:06 +02:00

73 lines
2.1 KiB
Python

'''Test the ability to add a label to a (potentially trained) parsing model.'''
from __future__ import unicode_literals
import pytest
import numpy.random
from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps
from ...attrs import NORM
from ...gold import GoldParse
from ...vocab import Vocab
from ...tokens import Doc
from ...pipeline import DependencyParser
# Seed NumPy's global random number generator when this module is imported.
# NOTE(review): presumably intended to make the parser initialisation and
# training in the fixtures below repeatable — confirm that the model code
# actually draws from numpy.random.
numpy.random.seed(0)
@pytest.fixture
def vocab():
    """Return a Vocab whose getter for NORM hands back its input unchanged."""
    def identity(string):
        return string

    getters = {NORM: identity}
    return Vocab(lex_attr_getters=getters)
@pytest.fixture
def parser(vocab):
    """Return a DependencyParser updated ten times on one toy example.

    The parser is created on ``vocab``, its cfg is set to
    token_vector_width=8, hidden_width=30 and hist_size=0, the label 'left'
    is added, and ``begin_training`` is called with an empty list and the cfg
    as keyword arguments. It is then updated ten times with an Adam optimizer
    against the same four-word gold parse.
    """
    words = ('a', 'b', 'c', 'd')
    gold_heads = (1, 1, 3, 3)
    gold_deps = ('left', 'ROOT', 'left', 'ROOT')

    model = DependencyParser(vocab)
    # Keep the insertion order: the cfg is later splatted into begin_training.
    for key, value in (('token_vector_width', 8),
                       ('hidden_width', 30),
                       ('hist_size', 0)):
        model.cfg[key] = value
    model.add_label('left')
    model.begin_training([], **model.cfg)

    optimizer = Adam(NumpyOps(), 0.001)
    for _ in range(10):
        # Fresh Doc, GoldParse, lists and losses dict on every step, exactly
        # as a literal written inside the loop would produce.
        example = Doc(vocab, words=list(words))
        target = GoldParse(example, heads=list(gold_heads),
                           deps=list(gold_deps))
        model.update([example], [target], sgd=optimizer, losses={})
    return model
def test_init_parser(parser):
    """Smoke test: has no assertions of its own and only requests ``parser``."""
# TODO: This is flakey, because it depends on what the parser first learns.
@pytest.mark.xfail
def test_add_label(parser):
    """Check that a label can be added to an already-updated parser.

    Steps:
      1. Parse the toy sentence and check the heads and the 'left' label the
         ``parser`` fixture was updated on.
      2. Add the label 'right' and check the same predictions again.
      3. Update ten more times with token 0 relabelled 'right', then check
         that token 0 is 'right' and token 2 is still 'left'.

    The test is marked xfail because its outcome is flaky.
    """
    words = ['a', 'b', 'c', 'd']

    def parse():
        # Always parse a fresh Doc so earlier annotations cannot leak in.
        return parser(Doc(parser.vocab, words=list(words)))

    def check_initial_parse(doc):
        assert doc[0].head.i == 1
        assert doc[0].dep_ == 'left'
        assert doc[1].head.i == 1
        assert doc[2].head.i == 3
        # The original asserted doc[2] twice; token 3 (gold head 3) was
        # never checked.
        assert doc[3].head.i == 3

    check_initial_parse(parse())

    parser.add_label('right')
    check_initial_parse(parse())

    sgd = Adam(NumpyOps(), 0.001)
    for _ in range(10):
        doc = Doc(parser.vocab, words=list(words))
        gold = GoldParse(doc, heads=[1, 1, 3, 3],
                         deps=['right', 'ROOT', 'left', 'ROOT'])
        parser.update([doc], [gold], sgd=sgd, losses={})

    doc = parse()
    assert doc[0].dep_ == 'right'
    assert doc[2].dep_ == 'left'