mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
d9a07a7f6e
v2.1 introduced a regression when deserializing the parser after parser.add_label() had been called. The code around the class mapping is currently quite confusing, because it was written to preserve backwards compatibility with existing models. It should be revised the next time the models are retrained. Closes #3433
71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
import pytest
|
|
from thinc.neural.optimizers import Adam
|
|
from thinc.neural.ops import NumpyOps
|
|
from spacy.attrs import NORM
|
|
from spacy.gold import GoldParse
|
|
from spacy.vocab import Vocab
|
|
from spacy.tokens import Doc
|
|
from spacy.pipeline import DependencyParser, EntityRecognizer
|
|
from spacy.util import fix_random_seed
|
|
|
|
|
|
@pytest.fixture
def vocab():
    """Provide a ``Vocab`` whose NORM getter hands back its input unchanged."""

    def _identity(string):
        return string

    return Vocab(lex_attr_getters={NORM: _identity})
|
|
|
|
|
|
@pytest.fixture
def parser(vocab):
    """Provide a newly constructed ``DependencyParser`` built on the ``vocab`` fixture."""
    return DependencyParser(vocab)
|
|
|
|
|
|
def test_init_parser(parser):
    """Smoke test: passes as long as the ``parser`` fixture can be built."""
|
|
|
|
|
|
def _train_parser(parser):
    """Run a short training routine on ``parser`` and return the same object.

    Calls ``fix_random_seed(1)``, registers the "left" label, calls
    ``begin_training`` with the parser's own config, and then performs five
    updates on one four-word example using a new Adam optimizer.
    """
    fix_random_seed(1)
    parser.add_label("left")
    parser.begin_training([], **parser.cfg)
    optimizer = Adam(NumpyOps(), 0.001)

    for _ in range(5):
        step_losses = {}
        # A new Doc and GoldParse are built on every step, as in the original.
        example = Doc(parser.vocab, words=["a", "b", "c", "d"])
        head_indices = [1, 1, 3, 3]
        dep_labels = ["left", "ROOT", "left", "ROOT"]
        target = GoldParse(example, heads=head_indices, deps=dep_labels)
        parser.update([example], [target], sgd=optimizer, losses=step_losses)
    return parser
|
|
|
|
|
|
def test_add_label(parser):
    """Check that a label added after initial training can be learned.

    The parser is first trained via ``_train_parser`` (which registers
    "left"), then "right" is added and ten more updates are run on an example
    that uses both labels. The parsed output must carry both labels.
    """
    trained = _train_parser(parser)
    trained.add_label("right")
    optimizer = Adam(NumpyOps(), 0.001)

    for _ in range(10):
        step_losses = {}
        example = Doc(trained.vocab, words=["a", "b", "c", "d"])
        head_indices = [1, 1, 3, 3]
        dep_labels = ["right", "ROOT", "left", "ROOT"]
        target = GoldParse(example, heads=head_indices, deps=dep_labels)
        trained.update([example], [target], sgd=optimizer, losses=step_losses)

    parsed = trained(Doc(trained.vocab, words=["a", "b", "c", "d"]))
    assert parsed[0].dep_ == "right"
    assert parsed[2].dep_ == "left"
|
|
|
|
|
|
def test_add_label_deserializes_correctly():
    """Check that a bytes round trip keeps the NER move classes in order.

    Labels are added as "C", "B", "A" before ``begin_training``; the
    deserialized recognizer must report the same number of moves and the
    same class name at every index.
    """
    original = EntityRecognizer(Vocab())
    for label in ("C", "B", "A"):
        original.add_label(label)
    original.begin_training([])

    restored = EntityRecognizer(Vocab()).from_bytes(original.to_bytes())

    assert original.moves.n_moves == restored.moves.n_moves
    for index in range(original.moves.n_moves):
        expected_name = original.moves.get_class_name(index)
        restored_name = restored.moves.get_class_name(index)
        assert expected_name == restored_name
|