mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Update train_new_entity_type example
This commit is contained in:
parent
d74dbde828
commit
4dc57d9e15
|
@ -45,19 +45,19 @@ LABEL = "ANIMAL"
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
(
|
(
|
||||||
"Horses are too tall and they pretend to care about your feelings",
|
"Horses are too tall and they pretend to care about your feelings",
|
||||||
{"entities": [(0, 6, "ANIMAL")]},
|
{"entities": [(0, 6, LABEL)]},
|
||||||
),
|
),
|
||||||
("Do they bite?", {"entities": []}),
|
("Do they bite?", {"entities": []}),
|
||||||
(
|
(
|
||||||
"horses are too tall and they pretend to care about your feelings",
|
"horses are too tall and they pretend to care about your feelings",
|
||||||
{"entities": [(0, 6, "ANIMAL")]},
|
{"entities": [(0, 6, LABEL)]},
|
||||||
),
|
),
|
||||||
("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}),
|
("horses pretend to care about your feelings", {"entities": [(0, 6, LABEL)]}),
|
||||||
(
|
(
|
||||||
"they pretend to care about your feelings, those horses",
|
"they pretend to care about your feelings, those horses",
|
||||||
{"entities": [(48, 54, "ANIMAL")]},
|
{"entities": [(48, 54, LABEL)]},
|
||||||
),
|
),
|
||||||
("horses?", {"entities": [(0, 6, "ANIMAL")]}),
|
("horses?", {"entities": [(0, 6, LABEL)]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,8 +67,9 @@ TRAIN_DATA = [
|
||||||
output_dir=("Optional output directory", "option", "o", Path),
|
output_dir=("Optional output directory", "option", "o", Path),
|
||||||
n_iter=("Number of training iterations", "option", "n", int),
|
n_iter=("Number of training iterations", "option", "n", int),
|
||||||
)
|
)
|
||||||
def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
|
||||||
"""Set up the pipeline and entity recognizer, and train the new entity."""
|
"""Set up the pipeline and entity recognizer, and train the new entity."""
|
||||||
|
random.seed(0)
|
||||||
if model is not None:
|
if model is not None:
|
||||||
nlp = spacy.load(model) # load existing spaCy model
|
nlp = spacy.load(model) # load existing spaCy model
|
||||||
print("Loaded model '%s'" % model)
|
print("Loaded model '%s'" % model)
|
||||||
|
@ -85,21 +86,22 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
||||||
ner = nlp.get_pipe("ner")
|
ner = nlp.get_pipe("ner")
|
||||||
|
|
||||||
ner.add_label(LABEL) # add new entity label to entity recognizer
|
ner.add_label(LABEL) # add new entity label to entity recognizer
|
||||||
|
# Adding extraneous labels shouldn't mess anything up
|
||||||
|
ner.add_label('VEGETABLE')
|
||||||
if model is None:
|
if model is None:
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
else:
|
else:
|
||||||
# Note that 'begin_training' initializes the models, so it'll zero out
|
optimizer = nlp.resume_training()
|
||||||
# existing entity types.
|
move_names = list(ner.move_names)
|
||||||
optimizer = nlp.entity.create_optimizer()
|
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||||
|
sizes = compounding(1.0, 4.0, 1.001)
|
||||||
|
# batch up the examples using spaCy's minibatch
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
random.shuffle(TRAIN_DATA)
|
random.shuffle(TRAIN_DATA)
|
||||||
|
batches = minibatch(TRAIN_DATA, size=sizes)
|
||||||
losses = {}
|
losses = {}
|
||||||
# batch up the examples using spaCy's minibatch
|
|
||||||
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
|
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
texts, annotations = zip(*batch)
|
texts, annotations = zip(*batch)
|
||||||
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
|
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
|
||||||
|
@ -124,6 +126,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
||||||
# test the saved model
|
# test the saved model
|
||||||
print("Loading from", output_dir)
|
print("Loading from", output_dir)
|
||||||
nlp2 = spacy.load(output_dir)
|
nlp2 = spacy.load(output_dir)
|
||||||
|
# Check the classes have loaded back consistently
|
||||||
|
assert nlp2.get_pipe('ner').move_names == move_names
|
||||||
doc2 = nlp2(test_text)
|
doc2 = nlp2(test_text)
|
||||||
for ent in doc2.ents:
|
for ent in doc2.ents:
|
||||||
print(ent.label_, ent.text)
|
print(ent.label_, ent.text)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user