From 5d8af404455bd8ed54ad4cfc99b44bb8beda3cc1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 23 Apr 2017 17:06:30 +0200 Subject: [PATCH] Add test for Issue #999 --- spacy/tests/regression/test_issue999.py | 75 +++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 spacy/tests/regression/test_issue999.py diff --git a/spacy/tests/regression/test_issue999.py b/spacy/tests/regression/test_issue999.py new file mode 100644 index 000000000..6e3e9af3f --- /dev/null +++ b/spacy/tests/regression/test_issue999.py @@ -0,0 +1,75 @@ +from __future__ import unicode_literals +import json +import os +import random +import contextlib +import shutil +import pytest +import tempfile +from pathlib import Path + + +import pathlib +from ...gold import GoldParse +from ...pipeline import EntityRecognizer +from ...en import English + +try: + unicode +except NameError: + unicode = str + + +@pytest.fixture +def train_data(): + return [ + ["hey",[]], + ["howdy",[]], + ["hey there",[]], + ["hello",[]], + ["hi",[]], + ["i'm looking for a place to eat",[]], + ["i'm looking for a place in the north of town",[[31,36,"location"]]], + ["show me chinese restaurants",[[8,15,"cuisine"]]], + ["show me chines restaurants",[[8,14,"cuisine"]]], + ] + + +@contextlib.contextmanager +def temp_save_model(model): + model_dir = Path(tempfile.mkdtemp()) + model.save_to_directory(model_dir) + yield model_dir + shutil.rmtree(model_dir.as_posix()) + + +def test_issue999(train_data): + '''Test that adding entities and resuming training works passably OK. + There are two issues here: + + 1) We have to readd labels. This isn't very nice. + 2) There's no way to set the learning rate for the weight update, so we + end up out-of-scale, causing it to learn too fast. + ''' + nlp = English(entity=False) + nlp.entity = EntityRecognizer(nlp.vocab, features=English.Defaults.entity_features) + for _, offsets in train_data: + for start, end, ent_type in offsets: + nlp.entity.add_label(ent_type) + for itn in range(10): + random.shuffle(train_data) + for raw_text, entity_offsets in train_data: + doc = nlp.make_doc(raw_text) + gold = GoldParse(doc, entities=entity_offsets) + loss = nlp.entity.update(doc, gold) + + with temp_save_model(nlp) as model_dir: + nlp2 = English(path=model_dir) + + for raw_text, entity_offsets in train_data: + doc = nlp2(raw_text) + ents = {(ent.start_char, ent.end_char): ent.label_ for ent in doc.ents} + for start, end, label in entity_offsets: + if (start, end) not in ents: + print(ents) + assert ents[(start, end)] == label