diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index 656ae1d83..b6fc84590 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -45,19 +45,19 @@ LABEL = "ANIMAL" TRAIN_DATA = [ ( "Horses are too tall and they pretend to care about your feelings", - {"entities": [(0, 6, "ANIMAL")]}, + {"entities": [(0, 6, LABEL)]}, ), ("Do they bite?", {"entities": []}), ( "horses are too tall and they pretend to care about your feelings", - {"entities": [(0, 6, "ANIMAL")]}, + {"entities": [(0, 6, LABEL)]}, ), - ("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}), + ("horses pretend to care about your feelings", {"entities": [(0, 6, LABEL)]}), ( "they pretend to care about your feelings, those horses", - {"entities": [(48, 54, "ANIMAL")]}, + {"entities": [(48, 54, LABEL)]}, ), - ("horses?", {"entities": [(0, 6, "ANIMAL")]}), + ("horses?", {"entities": [(0, 6, LABEL)]}), ] @@ -67,8 +67,9 @@ TRAIN_DATA = [ output_dir=("Optional output directory", "option", "o", Path), n_iter=("Number of training iterations", "option", "n", int), ) -def main(model=None, new_model_name="animal", output_dir=None, n_iter=10): +def main(model=None, new_model_name="animal", output_dir=None, n_iter=30): """Set up the pipeline and entity recognizer, and train the new entity.""" + random.seed(0) if model is not None: nlp = spacy.load(model) # load existing spaCy model print("Loaded model '%s'" % model) @@ -85,21 +86,22 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10): ner = nlp.get_pipe("ner") ner.add_label(LABEL) # add new entity label to entity recognizer + # Adding extraneous labels shouldn't mess anything up + ner.add_label('VEGETABLE') if model is None: optimizer = nlp.begin_training() else: - # Note that 'begin_training' initializes the models, so it'll zero out - # existing entity types. - optimizer = nlp.entity.create_optimizer() - + optimizer = nlp.resume_training() + move_names = list(ner.move_names) # get names of other pipes to disable them during training other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"] with nlp.disable_pipes(*other_pipes): # only train NER + sizes = compounding(1.0, 4.0, 1.001) + # batch up the examples using spaCy's minibatch for itn in range(n_iter): random.shuffle(TRAIN_DATA) + batches = minibatch(TRAIN_DATA, size=sizes) losses = {} - # batch up the examples using spaCy's minibatch - batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001)) for batch in batches: texts, annotations = zip(*batch) nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses) @@ -124,6 +126,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10): # test the saved model print("Loading from", output_dir) nlp2 = spacy.load(output_dir) + # Check the classes have loaded back consistently + assert nlp2.get_pipe('ner').move_names == move_names doc2 = nlp2(test_text) for ent in doc2.ents: print(ent.label_, ent.text)