spaCy/spacy/tests/parser/test_add_label.py
Sofie Van Landeghem e48a09df4e Example class for training data (#4543)
* OrigAnnot class instead of gold.orig_annot list of zipped tuples

* from_orig to replace from_annot_tuples

* rename to RawAnnot

* some unit tests for GoldParse creation and internal format

* removing orig_annot and switching to lists instead of tuple

* rewriting tuples to use RawAnnot (+ debug statements, WIP)

* fix pop() changing the data

* small fixes

* pop-append fixes

* return RawAnnot for existing GoldParse to have uniform interface

* clean up imports

* fix merge_sents

* add unit test for 4402 with new structure (not working yet)

* introduce DocAnnot

* typo fixes

* add unit test for merge_sents

* rename from_orig to from_raw

* fixing unit tests

* fix nn parser

* read_annots to produce text, doc_annot pairs

* _make_golds fix

* rename golds_to_gold_annots

* small fixes

* fix encoding

* have golds_to_gold_annots use DocAnnot

* missed a spot

* merge_sents as function in DocAnnot

* allow specifying only part of the token-level annotations

* refactor with Example class + underlying dicts

* pipeline components to work with Example objects (wip)

* input checking

* fix yielding

* fix calls to update

* small fixes

* fix scorer unit test with new format

* fix kwargs order

* fixes for ud and conllu scripts

* fix reading data for conllu script

* add proper errors (numbering not fixed yet, to avoid merge conflicts)

* fix a few more small bugs

* fix EL script
2019-11-11 17:35:27 +01:00

88 lines
2.5 KiB
Python

# coding: utf8
from __future__ import unicode_literals
import pytest
from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps
from spacy.attrs import NORM
from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.pipeline import DependencyParser, EntityRecognizer
from spacy.util import fix_random_seed
@pytest.fixture
def vocab():
    """Provide a Vocab whose NORM lexical attribute is the token text itself."""
    identity_norm = {NORM: lambda text: text}
    return Vocab(lex_attr_getters=identity_norm)
@pytest.fixture
def parser(vocab):
    """Provide a freshly constructed DependencyParser built on the ``vocab`` fixture."""
    return DependencyParser(vocab)
def test_init_parser(parser):
    """Smoke test: requesting the ``parser`` fixture must not raise."""
def _train_parser(parser):
    """Train ``parser`` on a tiny fixed example that only uses the "left" label.

    Seeds the random state, registers the "left" label, initialises the model
    with ``parser.cfg`` and then runs five update steps on the same
    four-token document, using an optimizer built as
    ``Adam(NumpyOps(), 0.001)``.

    Returns the same ``parser`` object it was given.
    """
    # Seed first so the initial weights and updates are reproducible.
    fix_random_seed(1)
    parser.add_label("left")
    parser.begin_training([], **parser.cfg)
    optimizer = Adam(NumpyOps(), 0.001)
    n_steps = 5
    for _ in range(n_steps):
        # A fresh losses dict, Doc and GoldParse are built on every step.
        step_losses = {}
        example_doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
        example_gold = GoldParse(
            example_doc,
            heads=[1, 1, 3, 3],
            deps=["left", "ROOT", "left", "ROOT"],
        )
        parser.update(
            (example_doc, example_gold), sgd=optimizer, losses=step_losses
        )
    return parser
def test_add_label(parser):
    """A label added after initial training can still be learned.

    The parser is first trained with only "left" (via ``_train_parser``).
    Then "right" is added and the model is updated ten more times on an
    example that uses both labels. Afterwards the parser must predict
    "right" for token 0 and "left" for token 2.
    """
    parser = _train_parser(parser)
    parser.add_label("right")
    optimizer = Adam(NumpyOps(), 0.001)

    def _new_doc():
        # Build a new four-token Doc (with a new word list) on every call.
        return Doc(parser.vocab, words=["a", "b", "c", "d"])

    for _ in range(10):
        step_losses = {}
        train_doc = _new_doc()
        train_gold = GoldParse(
            train_doc, heads=[1, 1, 3, 3], deps=["right", "ROOT", "left", "ROOT"]
        )
        parser.update((train_doc, train_gold), sgd=optimizer, losses=step_losses)

    parsed = parser(_new_doc())
    assert parsed[0].dep_ == "right"
    assert parsed[2].dep_ == "left"
def test_add_label_deserializes_correctly():
    """Labels added to an EntityRecognizer survive a to_bytes/from_bytes round trip.

    The reloaded component must report the same number of moves, and the
    same class name at every move index, as the original.
    """
    ner1 = EntityRecognizer(Vocab())
    # Labels are added in reverse-alphabetical order.
    for label in ("C", "B", "A"):
        ner1.add_label(label)
    ner1.begin_training([])
    ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
    n_moves = ner1.moves.n_moves
    assert n_moves == ner2.moves.n_moves
    for move_id in range(n_moves):
        assert ner1.moves.get_class_name(move_id) == ner2.moves.get_class_name(
            move_id
        )
@pytest.mark.parametrize(
    "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
)
def test_add_label_get_label(pipe_cls, n_moves):
    """Test that added labels are returned correctly. This test was added to
    test for a bug in DependencyParser.labels that'd cause it to fail when
    splitting the move names.

    ``n_moves`` is the number of move names each added label is expected to
    contribute for the given ``pipe_cls``.
    """
    labels = ["A", "B", "C"]
    pipe = pipe_cls(Vocab())
    for label in labels:
        pipe.add_label(label)
    assert len(pipe.move_names) == len(labels) * n_moves
    # Sort both sides so the comparison does not depend on the order in which
    # pipe.labels is produced, nor on `labels` happening to be sorted already.
    assert sorted(pipe.labels) == sorted(labels)