mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Update train_new_entity_type example
This commit is contained in:
parent
d74dbde828
commit
4dc57d9e15
|
@ -45,19 +45,19 @@ LABEL = "ANIMAL"
|
|||
TRAIN_DATA = [
|
||||
(
|
||||
"Horses are too tall and they pretend to care about your feelings",
|
||||
{"entities": [(0, 6, "ANIMAL")]},
|
||||
{"entities": [(0, 6, LABEL)]},
|
||||
),
|
||||
("Do they bite?", {"entities": []}),
|
||||
(
|
||||
"horses are too tall and they pretend to care about your feelings",
|
||||
{"entities": [(0, 6, "ANIMAL")]},
|
||||
{"entities": [(0, 6, LABEL)]},
|
||||
),
|
||||
("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}),
|
||||
("horses pretend to care about your feelings", {"entities": [(0, 6, LABEL)]}),
|
||||
(
|
||||
"they pretend to care about your feelings, those horses",
|
||||
{"entities": [(48, 54, "ANIMAL")]},
|
||||
{"entities": [(48, 54, LABEL)]},
|
||||
),
|
||||
("horses?", {"entities": [(0, 6, "ANIMAL")]}),
|
||||
("horses?", {"entities": [(0, 6, LABEL)]}),
|
||||
]
|
||||
|
||||
|
||||
|
@ -67,8 +67,9 @@ TRAIN_DATA = [
|
|||
output_dir=("Optional output directory", "option", "o", Path),
|
||||
n_iter=("Number of training iterations", "option", "n", int),
|
||||
)
|
||||
def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
||||
def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
|
||||
"""Set up the pipeline and entity recognizer, and train the new entity."""
|
||||
random.seed(0)
|
||||
if model is not None:
|
||||
nlp = spacy.load(model) # load existing spaCy model
|
||||
print("Loaded model '%s'" % model)
|
||||
|
@ -85,21 +86,22 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
|||
ner = nlp.get_pipe("ner")
|
||||
|
||||
ner.add_label(LABEL) # add new entity label to entity recognizer
|
||||
# Adding extraneous labels shouldn't mess anything up
|
||||
ner.add_label('VEGETABLE')
|
||||
if model is None:
|
||||
optimizer = nlp.begin_training()
|
||||
else:
|
||||
# Note that 'begin_training' initializes the models, so it'll zero out
|
||||
# existing entity types.
|
||||
optimizer = nlp.entity.create_optimizer()
|
||||
|
||||
optimizer = nlp.resume_training()
|
||||
move_names = list(ner.move_names)
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||
sizes = compounding(1.0, 4.0, 1.001)
|
||||
# batch up the examples using spaCy's minibatch
|
||||
for itn in range(n_iter):
|
||||
random.shuffle(TRAIN_DATA)
|
||||
batches = minibatch(TRAIN_DATA, size=sizes)
|
||||
losses = {}
|
||||
# batch up the examples using spaCy's minibatch
|
||||
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
|
||||
for batch in batches:
|
||||
texts, annotations = zip(*batch)
|
||||
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
|
||||
|
@ -124,6 +126,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=10):
|
|||
# test the saved model
|
||||
print("Loading from", output_dir)
|
||||
nlp2 = spacy.load(output_dir)
|
||||
# Check the classes have loaded back consistently
|
||||
assert nlp2.get_pipe('ner').move_names == move_names
|
||||
doc2 = nlp2(test_text)
|
||||
for ent in doc2.ents:
|
||||
print(ent.label_, ent.text)
|
||||
|
|
Loading…
Reference in New Issue
Block a user