Update example for adding entity type

This commit is contained in:
Matthew Honnibal 2017-09-14 16:15:59 +02:00
parent 9cb2aef587
commit 683d81bb49

View File

@ -25,7 +25,7 @@ For more details, see the documentation:
* Saving and loading models: https://spacy.io/docs/usage/saving-loading
Developed for: spaCy 1.7.6
Last tested for: spaCy 1.7.6
Last updated for: spaCy 2.0.0a13
"""
from __future__ import unicode_literals, print_function
@ -34,55 +34,41 @@ from pathlib import Path
import random
import spacy
from spacy.gold import GoldParse
from spacy.tagger import Tagger
from spacy.gold import GoldParse, minibatch
from spacy.pipeline import NeuralEntityRecognizer
from spacy.pipeline import TokenVectorEncoder
def get_gold_parses(tokenizer, train_data):
    """Shuffle ``train_data`` in place and lazily yield ``(doc, gold)`` pairs.

    ``tokenizer`` is called once per example with the raw text, and its
    result is handed to ``GoldParse`` together with that example's entity
    offsets. ``train_data`` is iterated as ``(raw_text, entity_offsets)``
    pairs.

    Because this is a generator, the shuffle only runs once the caller
    starts iterating, and each fresh call reshuffles the same list object.
    """
    random.shuffle(train_data)
    for text, offsets in train_data:
        tokens = tokenizer(text)
        yield tokens, GoldParse(tokens, entities=offsets)
def train_ner(nlp, train_data, output_dir):
# Train the pipeline held by `nlp` on `train_data` and, when `output_dir` is
# set, write the result to that directory.
# NOTE(review): this span appears to interleave lines from two revisions of
# the function (for example both `nlp.save_to_directory(...)` and
# `nlp.to_disk(...)` occur, and there are two training loops) -- confirm
# against the real file before relying on the statement order shown here.
# Add new words to vocab
for raw_text, _ in train_data:
doc = nlp.make_doc(raw_text)
for word in doc:
# Looking the word up in nlp.vocab is done for its effect on the vocab;
# the returned value is discarded.
_ = nlp.vocab[word.orth]
# Fixed seed so the shuffles below are reproducible from run to run.
random.seed(0)
# You may need to change the learning rate. It's generally difficult to
# guess what rate you should set, especially when you have limited data.
nlp.entity.model.learn_rate = 0.001
for itn in range(1000):
random.shuffle(train_data)
loss = 0.
for raw_text, entity_offsets in train_data:
# NOTE(review): `doc` is read here before the `doc = nlp.make_doc(raw_text)`
# assignment further down in this loop body, so on each pass it still holds
# the value from an earlier assignment -- confirm whether that is intended.
gold = GoldParse(doc, entities=entity_offsets)
# By default, the GoldParse class assumes that the entities
# described by offset are complete, and all other words should
# have the tag 'O'. You can tell it to make no assumptions
# about the tag of a word by giving it the tag '-'.
# However, this allows a trivial solution to the current
# learning problem: if words are either 'any tag' or 'ANIMAL',
# the model can learn that all words can be tagged 'ANIMAL'.
#for i in range(len(gold.ner)):
#if not gold.ner[i].endswith('ANIMAL'):
# gold.ner[i] = '-'
doc = nlp.make_doc(raw_text)
nlp.tagger(doc)
# As of 1.9, spaCy's parser now lets you supply a dropout probability
# This might help the model generalize better from only a few
# examples.
loss += nlp.entity.update(doc, gold, drop=0.9)
# Stop early once the accumulated loss is exactly zero.
if loss == 0:
break
# This step averages the model's weights. This may or may not be good for
# your situation --- it's empirical.
nlp.end_training()
if output_dir:
if not output_dir.exists():
output_dir.mkdir()
nlp.save_to_directory(output_dir)
# The object returned by begin_training is passed to nlp.update as `sgd` below.
optimizer = nlp.begin_training(lambda: [])
nlp.meta['name'] = 'en_ent_animal'
for itn in range(50):
# Fresh dict handed to nlp.update through `losses=` and printed afterwards.
losses = {}
# get_gold_parses reshuffles train_data each time it is iterated; minibatch
# is asked for groups of size 3 (presumably lists of (doc, gold) pairs --
# verify against spacy.gold.minibatch).
for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
# Split the batch's pairs into one tuple of docs and one tuple of golds.
docs, golds = zip(*batch)
nlp.update(docs, golds, losses=losses, sgd=optimizer, update_shared=True,
drop=0.35)
print(losses)
# Nothing is written when no output directory was supplied.
if not output_dir:
return
elif not output_dir.exists():
output_dir.mkdir()
nlp.to_disk(output_dir)
def main(model_name, output_directory=None):
print("Loading initial model", model_name)
nlp = spacy.load(model_name)
print("Creating initial model", model_name)
nlp = spacy.blank(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
@ -91,6 +77,11 @@ def main(model_name, output_directory=None):
"Horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')],
),
(
"Do they bite?",
[],
),
(
"horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')]
@ -109,18 +100,20 @@ def main(model_name, output_directory=None):
)
]
nlp.entity.add_label('ANIMAL')
nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
nlp.pipeline[-1].add_label('ANIMAL')
train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
doc = nlp('Do you like horses?')
text = 'Do you like horses?'
print("Ents in 'Do you like horses?':")
doc = nlp(text)
for ent in doc.ents:
print(ent.label_, ent.text)
if output_directory:
print("Loading from", output_directory)
nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL')
nlp2 = spacy.load(output_directory)
doc2 = nlp2('Do you like horses?')
for ent in doc2.ents:
print(ent.label_, ent.text)