mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 00:32:40 +03:00
Update example for adding entity type
This commit is contained in:
parent
9cb2aef587
commit
683d81bb49
|
@ -25,7 +25,7 @@ For more details, see the documentation:
|
||||||
* Saving and loading models: https://spacy.io/docs/usage/saving-loading
|
* Saving and loading models: https://spacy.io/docs/usage/saving-loading
|
||||||
|
|
||||||
Developed for: spaCy 1.7.6
|
Developed for: spaCy 1.7.6
|
||||||
Last tested for: spaCy 1.7.6
|
Last updated for: spaCy 2.0.0a13
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
|
@ -34,55 +34,41 @@ from pathlib import Path
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse, minibatch
|
||||||
from spacy.tagger import Tagger
|
from spacy.pipeline import NeuralEntityRecognizer
|
||||||
|
from spacy.pipeline import TokenVectorEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def get_gold_parses(tokenizer, train_data):
    """Shuffle the training examples and yield (Doc, GoldParse) pairs.

    tokenizer: callable mapping a raw text string to a Doc
        (e.g. ``nlp.make_doc``).
    train_data: iterable of (raw_text, entity_offsets) pairs, where
        entity_offsets is a list of (start, end, label) tuples.

    Yields one (doc, gold) tuple per example, in freshly shuffled order.
    """
    # Shuffle a copy so the caller's list is not mutated as a side effect;
    # each call (one per training epoch) still yields a new random order.
    examples = list(train_data)
    random.shuffle(examples)
    for raw_text, entity_offsets in examples:
        doc = tokenizer(raw_text)
        gold = GoldParse(doc, entities=entity_offsets)
        yield doc, gold
||||||
|
|
||||||
def train_ner(nlp, train_data, output_dir):
    """Update the pipeline's entity recognizer on the training examples.

    nlp: a spaCy Language object whose pipeline already contains an
        entity recognizer with the target label(s) added.
    train_data: list of (raw_text, entity_offsets) training examples.
    output_dir: Path to save the trained model to, or a falsy value to
        skip saving.
    """
    random.seed(0)  # deterministic example shuffling across runs
    optimizer = nlp.begin_training(lambda: [])
    # Name recorded in the saved model's meta.json.
    nlp.meta['name'] = 'en_ent_animal'
    for itn in range(50):
        losses = {}
        # Re-shuffle and re-batch the examples every epoch.
        gold_stream = get_gold_parses(nlp.make_doc, train_data)
        for batch in minibatch(gold_stream, size=3):
            docs, golds = zip(*batch)
            # Dropout (0.35) helps the model generalize from only a
            # few examples; update_shared also trains the shared
            # token-to-vector weights — TODO confirm against the
            # spaCy 2.0.0a13 API this example targets.
            nlp.update(docs, golds, losses=losses, sgd=optimizer,
                       update_shared=True, drop=0.35)
        print(losses)
    if not output_dir:
        return
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
|
|
||||||
def main(model_name, output_directory=None):
|
def main(model_name, output_directory=None):
|
||||||
print("Loading initial model", model_name)
|
print("Creating initial model", model_name)
|
||||||
nlp = spacy.load(model_name)
|
nlp = spacy.blank(model_name)
|
||||||
if output_directory is not None:
|
if output_directory is not None:
|
||||||
output_directory = Path(output_directory)
|
output_directory = Path(output_directory)
|
||||||
|
|
||||||
|
@ -91,6 +77,11 @@ def main(model_name, output_directory=None):
|
||||||
"Horses are too tall and they pretend to care about your feelings",
|
"Horses are too tall and they pretend to care about your feelings",
|
||||||
[(0, 6, 'ANIMAL')],
|
[(0, 6, 'ANIMAL')],
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"Do they bite?",
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
|
||||||
(
|
(
|
||||||
"horses are too tall and they pretend to care about your feelings",
|
"horses are too tall and they pretend to care about your feelings",
|
||||||
[(0, 6, 'ANIMAL')]
|
[(0, 6, 'ANIMAL')]
|
||||||
|
@ -109,18 +100,20 @@ def main(model_name, output_directory=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
]
|
]
|
||||||
nlp.entity.add_label('ANIMAL')
|
nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
|
||||||
|
nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
|
||||||
|
nlp.pipeline[-1].add_label('ANIMAL')
|
||||||
train_ner(nlp, train_data, output_directory)
|
train_ner(nlp, train_data, output_directory)
|
||||||
|
|
||||||
# Test that the entity is recognized
|
# Test that the entity is recognized
|
||||||
doc = nlp('Do you like horses?')
|
text = 'Do you like horses?'
|
||||||
print("Ents in 'Do you like horses?':")
|
print("Ents in 'Do you like horses?':")
|
||||||
|
doc = nlp(text)
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
print(ent.label_, ent.text)
|
print(ent.label_, ent.text)
|
||||||
if output_directory:
|
if output_directory:
|
||||||
print("Loading from", output_directory)
|
print("Loading from", output_directory)
|
||||||
nlp2 = spacy.load('en', path=output_directory)
|
nlp2 = spacy.load(output_directory)
|
||||||
nlp2.entity.add_label('ANIMAL')
|
|
||||||
doc2 = nlp2('Do you like horses?')
|
doc2 = nlp2('Do you like horses?')
|
||||||
for ent in doc2.ents:
|
for ent in doc2.ents:
|
||||||
print(ent.label_, ent.text)
|
print(ent.label_, ent.text)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user