From 040751ad17c96a6e6bf2c0d70e7e789cbd8dd7d6 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 23 Apr 2017 16:28:55 +0200 Subject: [PATCH] Remove xfail on Test #910 --- spacy/tests/regression/test_issue910.py | 27 +++++++++++-------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/spacy/tests/regression/test_issue910.py b/spacy/tests/regression/test_issue910.py index 9b2b2287b..3790d429b 100644 --- a/spacy/tests/regression/test_issue910.py +++ b/spacy/tests/regression/test_issue910.py @@ -70,7 +70,6 @@ def temp_save_model(model): -@pytest.mark.xfail @pytest.mark.models def test_issue910(train_data, additional_entity_types): '''Test that adding entities and resuming training works passably OK. @@ -85,11 +84,10 @@ def test_issue910(train_data, additional_entity_types): ents_before_train = [(ent.label_, ent.text) for ent in doc.ents] # Fine tune the ner model for entity_type in additional_entity_types: - if entity_type not in nlp.entity.cfg['actions']['1']: - nlp.entity.add_label(entity_type) + nlp.entity.add_label(entity_type) - nlp.entity.learn_rate = 0.001 - for itn in range(4): + nlp.entity.model.learn_rate = 0.001 + for itn in range(10): random.shuffle(train_data) for raw_text, entity_offsets in train_data: doc = nlp.make_doc(raw_text) @@ -101,13 +99,12 @@ def test_issue910(train_data, additional_entity_types): # Load the fine tuned model loaded_ner = EntityRecognizer.load(model_dir, nlp.vocab) - for entity_type in additional_entity_types: - if entity_type not in loaded_ner.cfg['actions']['1']: - loaded_ner.add_label(entity_type) - - doc = nlp(u"I am looking for a restaurant in Berlin", entity=False) - nlp.tagger(doc) - loaded_ner(doc) - - ents_after_train = [(ent.label_, ent.text) for ent in doc.ents] - assert ents_before_train == ents_after_train + for raw_text, entity_offsets in train_data: + doc = nlp.make_doc(raw_text) + nlp.tagger(doc) + loaded_ner(doc) + ents = {(ent.start_char, ent.end_char): ent.label_ for ent in doc.ents} + for start, end, label in entity_offsets: + if (start, end) not in ents: + print(ents) + assert ents[(start, end)] == label