Fix train_new_entity_type example

This commit is contained in:
Matthew Honnibal 2017-04-18 13:47:36 +02:00
parent 6a4221a6de
commit 2f84626417

View File

@ -4,7 +4,6 @@ import random
from pathlib import Path
import spacy
from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse
from spacy.tagger import Tagger
@ -25,10 +24,13 @@ def train_ner(nlp, train_data, output_dir):
loss = nlp.entity.update(doc, gold)
nlp.end_training()
if output_dir:
if not output_dir.exists():
output_dir.mkdir()
nlp.save_to_directory(output_dir)
def main(model_name, output_directory=None):
print("Loading initial model", model_name)
nlp = spacy.load(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
@ -52,13 +54,14 @@ def main(model_name, output_directory=None):
)
]
nlp.entity.add_label('ANIMAL')
ner = train_ner(nlp, train_data, output_directory)
train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
doc = nlp('Do you like horses?')
for ent in doc.ents:
print(ent.label_, ent.text)
if output_directory:
print("Loading from", output_directory)
nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL')
doc2 = nlp2('Do you like horses?')