mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Remove xfail on Test #910
This commit is contained in:
parent
ade920c30f
commit
040751ad17
|
@ -70,7 +70,6 @@ def temp_save_model(model):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
@pytest.mark.models
|
@pytest.mark.models
|
||||||
def test_issue910(train_data, additional_entity_types):
|
def test_issue910(train_data, additional_entity_types):
|
||||||
'''Test that adding entities and resuming training works passably OK.
|
'''Test that adding entities and resuming training works passably OK.
|
||||||
|
@ -85,11 +84,10 @@ def test_issue910(train_data, additional_entity_types):
|
||||||
ents_before_train = [(ent.label_, ent.text) for ent in doc.ents]
|
ents_before_train = [(ent.label_, ent.text) for ent in doc.ents]
|
||||||
# Fine tune the ner model
|
# Fine tune the ner model
|
||||||
for entity_type in additional_entity_types:
|
for entity_type in additional_entity_types:
|
||||||
if entity_type not in nlp.entity.cfg['actions']['1']:
|
nlp.entity.add_label(entity_type)
|
||||||
nlp.entity.add_label(entity_type)
|
|
||||||
|
|
||||||
nlp.entity.learn_rate = 0.001
|
nlp.entity.model.learn_rate = 0.001
|
||||||
for itn in range(4):
|
for itn in range(10):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
for raw_text, entity_offsets in train_data:
|
for raw_text, entity_offsets in train_data:
|
||||||
doc = nlp.make_doc(raw_text)
|
doc = nlp.make_doc(raw_text)
|
||||||
|
@ -101,13 +99,12 @@ def test_issue910(train_data, additional_entity_types):
|
||||||
# Load the fine tuned model
|
# Load the fine tuned model
|
||||||
loaded_ner = EntityRecognizer.load(model_dir, nlp.vocab)
|
loaded_ner = EntityRecognizer.load(model_dir, nlp.vocab)
|
||||||
|
|
||||||
for entity_type in additional_entity_types:
|
for raw_text, entity_offsets in train_data:
|
||||||
if entity_type not in loaded_ner.cfg['actions']['1']:
|
doc = nlp.make_doc(raw_text)
|
||||||
loaded_ner.add_label(entity_type)
|
nlp.tagger(doc)
|
||||||
|
loaded_ner(doc)
|
||||||
doc = nlp(u"I am looking for a restaurant in Berlin", entity=False)
|
ents = {(ent.start_char, ent.end_char): ent.label_ for ent in doc.ents}
|
||||||
nlp.tagger(doc)
|
for start, end, label in entity_offsets:
|
||||||
loaded_ner(doc)
|
if (start, end) not in ents:
|
||||||
|
print(ents)
|
||||||
ents_after_train = [(ent.label_, ent.text) for ent in doc.ents]
|
assert ents[(start, end)] == label
|
||||||
assert ents_before_train == ents_after_train
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user