Fix train_new_entity_type example

This commit is contained in:
Matthew Honnibal 2017-04-18 13:47:36 +02:00
parent 6a4221a6de
commit 2f84626417

View File

@ -4,7 +4,6 @@ import random
from pathlib import Path from pathlib import Path
import spacy import spacy
from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.tagger import Tagger from spacy.tagger import Tagger
@ -25,10 +24,13 @@ def train_ner(nlp, train_data, output_dir):
loss = nlp.entity.update(doc, gold) loss = nlp.entity.update(doc, gold)
nlp.end_training() nlp.end_training()
if output_dir: if output_dir:
if not output_dir.exists():
output_dir.mkdir()
nlp.save_to_directory(output_dir) nlp.save_to_directory(output_dir)
def main(model_name, output_directory=None): def main(model_name, output_directory=None):
print("Loading initial model", model_name)
nlp = spacy.load(model_name) nlp = spacy.load(model_name)
if output_directory is not None: if output_directory is not None:
output_directory = Path(output_directory) output_directory = Path(output_directory)
@ -52,13 +54,14 @@ def main(model_name, output_directory=None):
) )
] ]
nlp.entity.add_label('ANIMAL') nlp.entity.add_label('ANIMAL')
ner = train_ner(nlp, train_data, output_directory) train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized # Test that the entity is recognized
doc = nlp('Do you like horses?') doc = nlp('Do you like horses?')
for ent in doc.ents: for ent in doc.ents:
print(ent.label_, ent.text) print(ent.label_, ent.text)
if output_directory: if output_directory:
print("Loading from", output_directory)
nlp2 = spacy.load('en', path=output_directory) nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL') nlp2.entity.add_label('ANIMAL')
doc2 = nlp2('Do you like horses?') doc2 = nlp2('Do you like horses?')