Update train_new_entity_type example

This commit is contained in:
Matthew Honnibal 2019-02-24 16:41:03 +01:00
parent d74dbde828
commit 4dc57d9e15

View File

@ -45,19 +45,19 @@ LABEL = "ANIMAL"
TRAIN_DATA = [
(
"Horses are too tall and they pretend to care about your feelings",
{"entities": [(0, 6, "ANIMAL")]},
{"entities": [(0, 6, LABEL)]},
),
("Do they bite?", {"entities": []}),
(
"horses are too tall and they pretend to care about your feelings",
{"entities": [(0, 6, "ANIMAL")]},
{"entities": [(0, 6, LABEL)]},
),
("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}),
("horses pretend to care about your feelings", {"entities": [(0, 6, LABEL)]}),
(
"they pretend to care about your feelings, those horses",
{"entities": [(48, 54, "ANIMAL")]},
{"entities": [(48, 54, LABEL)]},
),
("horses?", {"entities": [(0, 6, "ANIMAL")]}),
("horses?", {"entities": [(0, 6, LABEL)]}),
]
@ -67,8 +67,9 @@ TRAIN_DATA = [
output_dir=("Optional output directory", "option", "o", Path),
n_iter=("Number of training iterations", "option", "n", int),
)
def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
"""Set up the pipeline and entity recognizer, and train the new entity."""
random.seed(0)
if model is not None:
nlp = spacy.load(model) # load existing spaCy model
print("Loaded model '%s'" % model)
@ -85,21 +86,22 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
ner = nlp.get_pipe("ner")
ner.add_label(LABEL) # add new entity label to entity recognizer
# Adding extraneous labels shouldn't mess anything up
ner.add_label('VEGETABLE')
if model is None:
optimizer = nlp.begin_training()
else:
# Note that 'begin_training' initializes the models, so it'll zero out
# existing entity types.
optimizer = nlp.entity.create_optimizer()
optimizer = nlp.resume_training()
move_names = list(ner.move_names)
# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
with nlp.disable_pipes(*other_pipes): # only train NER
sizes = compounding(1.0, 4.0, 1.001)
# batch up the examples using spaCy's minibatch
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
batches = minibatch(TRAIN_DATA, size=sizes)
losses = {}
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
@ -124,6 +126,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
# test the saved model
print("Loading from", output_dir)
nlp2 = spacy.load(output_dir)
# Check the classes have loaded back consistently
assert nlp2.get_pipe('ner').move_names == move_names
doc2 = nlp2(test_text)
for ent in doc2.ents:
print(ent.label_, ent.text)