Update train_new_entity_type example

This commit is contained in:
Matthew Honnibal 2019-02-24 16:41:03 +01:00
parent d74dbde828
commit 4dc57d9e15

View File

@ -45,19 +45,19 @@ LABEL = "ANIMAL"
TRAIN_DATA = [ TRAIN_DATA = [
( (
"Horses are too tall and they pretend to care about your feelings", "Horses are too tall and they pretend to care about your feelings",
{"entities": [(0, 6, "ANIMAL")]}, {"entities": [(0, 6, LABEL)]},
), ),
("Do they bite?", {"entities": []}), ("Do they bite?", {"entities": []}),
( (
"horses are too tall and they pretend to care about your feelings", "horses are too tall and they pretend to care about your feelings",
{"entities": [(0, 6, "ANIMAL")]}, {"entities": [(0, 6, LABEL)]},
), ),
("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}), ("horses pretend to care about your feelings", {"entities": [(0, 6, LABEL)]}),
( (
"they pretend to care about your feelings, those horses", "they pretend to care about your feelings, those horses",
{"entities": [(48, 54, "ANIMAL")]}, {"entities": [(48, 54, LABEL)]},
), ),
("horses?", {"entities": [(0, 6, "ANIMAL")]}), ("horses?", {"entities": [(0, 6, LABEL)]}),
] ]
@ -67,8 +67,9 @@ TRAIN_DATA = [
output_dir=("Optional output directory", "option", "o", Path), output_dir=("Optional output directory", "option", "o", Path),
n_iter=("Number of training iterations", "option", "n", int), n_iter=("Number of training iterations", "option", "n", int),
) )
def main(model=None, new_model_name="animal", output_dir=None, n_iter=10): def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
"""Set up the pipeline and entity recognizer, and train the new entity.""" """Set up the pipeline and entity recognizer, and train the new entity."""
random.seed(0)
if model is not None: if model is not None:
nlp = spacy.load(model) # load existing spaCy model nlp = spacy.load(model) # load existing spaCy model
print("Loaded model '%s'" % model) print("Loaded model '%s'" % model)
@ -85,21 +86,22 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
ner = nlp.get_pipe("ner") ner = nlp.get_pipe("ner")
ner.add_label(LABEL) # add new entity label to entity recognizer ner.add_label(LABEL) # add new entity label to entity recognizer
# Adding extraneous labels shouldn't mess anything up
ner.add_label('VEGETABLE')
if model is None: if model is None:
optimizer = nlp.begin_training() optimizer = nlp.begin_training()
else: else:
# Note that 'begin_training' initializes the models, so it'll zero out optimizer = nlp.resume_training()
# existing entity types. move_names = list(ner.move_names)
optimizer = nlp.entity.create_optimizer()
# get names of other pipes to disable them during training # get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"] other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
with nlp.disable_pipes(*other_pipes): # only train NER with nlp.disable_pipes(*other_pipes): # only train NER
sizes = compounding(1.0, 4.0, 1.001)
# batch up the examples using spaCy's minibatch
for itn in range(n_iter): for itn in range(n_iter):
random.shuffle(TRAIN_DATA) random.shuffle(TRAIN_DATA)
batches = minibatch(TRAIN_DATA, size=sizes)
losses = {} losses = {}
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches: for batch in batches:
texts, annotations = zip(*batch) texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses) nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
@ -124,6 +126,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
# test the saved model # test the saved model
print("Loading from", output_dir) print("Loading from", output_dir)
nlp2 = spacy.load(output_dir) nlp2 = spacy.load(output_dir)
# Check the classes have loaded back consistently
assert nlp2.get_pipe('ner').move_names == move_names
doc2 = nlp2(test_text) doc2 = nlp2(test_text)
for ent in doc2.ents: for ent in doc2.ents:
print(ent.label_, ent.text) print(ent.label_, ent.text)