diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index cbe2963d3..23cb86596 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -4,7 +4,6 @@ import random from pathlib import Path import spacy -from spacy.pipeline import EntityRecognizer from spacy.gold import GoldParse from spacy.tagger import Tagger @@ -25,10 +24,13 @@ def train_ner(nlp, train_data, output_dir): loss = nlp.entity.update(doc, gold) nlp.end_training() if output_dir: + if not output_dir.exists(): + output_dir.mkdir() nlp.save_to_directory(output_dir) def main(model_name, output_directory=None): + print("Loading initial model", model_name) nlp = spacy.load(model_name) if output_directory is not None: output_directory = Path(output_directory) @@ -52,13 +54,14 @@ def main(model_name, output_directory=None): ) ] nlp.entity.add_label('ANIMAL') - ner = train_ner(nlp, train_data, output_directory) + train_ner(nlp, train_data, output_directory) # Test that the entity is recognized doc = nlp('Do you like horses?') for ent in doc.ents: print(ent.label_, ent.text) if output_directory: + print("Loading from", output_directory) nlp2 = spacy.load('en', path=output_directory) nlp2.entity.add_label('ANIMAL') doc2 = nlp2('Do you like horses?')