Remove xfail on Test #910

This commit is contained in:
Matthew Honnibal 2017-04-23 16:28:55 +02:00
parent ade920c30f
commit 040751ad17

View File

@ -70,7 +70,6 @@ def temp_save_model(model):
@pytest.mark.xfail
@pytest.mark.models @pytest.mark.models
def test_issue910(train_data, additional_entity_types): def test_issue910(train_data, additional_entity_types):
'''Test that adding entities and resuming training works passably OK. '''Test that adding entities and resuming training works passably OK.
@ -85,11 +84,10 @@ def test_issue910(train_data, additional_entity_types):
ents_before_train = [(ent.label_, ent.text) for ent in doc.ents] ents_before_train = [(ent.label_, ent.text) for ent in doc.ents]
# Fine tune the ner model # Fine tune the ner model
for entity_type in additional_entity_types: for entity_type in additional_entity_types:
if entity_type not in nlp.entity.cfg['actions']['1']: nlp.entity.add_label(entity_type)
nlp.entity.add_label(entity_type)
nlp.entity.learn_rate = 0.001 nlp.entity.model.learn_rate = 0.001
for itn in range(4): for itn in range(10):
random.shuffle(train_data) random.shuffle(train_data)
for raw_text, entity_offsets in train_data: for raw_text, entity_offsets in train_data:
doc = nlp.make_doc(raw_text) doc = nlp.make_doc(raw_text)
@ -101,13 +99,12 @@ def test_issue910(train_data, additional_entity_types):
# Load the fine tuned model # Load the fine tuned model
loaded_ner = EntityRecognizer.load(model_dir, nlp.vocab) loaded_ner = EntityRecognizer.load(model_dir, nlp.vocab)
for entity_type in additional_entity_types: for raw_text, entity_offsets in train_data:
if entity_type not in loaded_ner.cfg['actions']['1']: doc = nlp.make_doc(raw_text)
loaded_ner.add_label(entity_type) nlp.tagger(doc)
loaded_ner(doc)
doc = nlp(u"I am looking for a restaurant in Berlin", entity=False) ents = {(ent.start_char, ent.end_char): ent.label_ for ent in doc.ents}
nlp.tagger(doc) for start, end, label in entity_offsets:
loaded_ner(doc) if (start, end) not in ents:
print(ents)
ents_after_train = [(ent.label_, ent.text) for ent in doc.ents] assert ents[(start, end)] == label
assert ents_before_train == ents_after_train