Use plac annotations for command-line arguments and add n_iter option

This commit is contained in:
ines 2017-10-26 16:11:05 +02:00
parent bc2c92f22d
commit c3b681e5fb

View File

@@ -29,6 +29,7 @@ Last updated for: spaCy 2.0.0a18
""" """
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
import plac
import random import random
from pathlib import Path from pathlib import Path
@@ -58,16 +59,13 @@ TRAIN_DATA = [
] ]
def main(model=None, new_model_name='animal', output_dir=None): @plac.annotations(
"""Set up the pipeline and entity recognizer, and train the new entity. model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
new_model_name=("New model name for model meta.", "option", "nm", str),
model (unicode): Model name to start off with. If None, a blank English output_dir=("Optional output directory", "option", "o", Path),
Language class is created. n_iter=("Number of training iterations", "option", "n", int))
new_model_name (unicode): Name of new model to create. Will be added to the def main(model=None, new_model_name='animal', output_dir=None, n_iter=50):
model meta and prefixed by the language code, e.g. 'en_animal'. """Set up the pipeline and entity recognizer, and train the new entity."""
output_dir (unicode / Path): Optional output directory. If None, no model
will be saved.
"""
if model is not None: if model is not None:
nlp = spacy.load(model) # load existing spaCy model nlp = spacy.load(model) # load existing spaCy model
print("Loaded model '%s'" % model) print("Loaded model '%s'" % model)
@@ -91,7 +89,7 @@ def main(model=None, new_model_name='animal', output_dir=None):
with nlp.disable_pipes(*other_pipes) as disabled: # only train NER with nlp.disable_pipes(*other_pipes) as disabled: # only train NER
random.seed(0) random.seed(0)
optimizer = nlp.begin_training(lambda: []) optimizer = nlp.begin_training(lambda: [])
for itn in range(50): for itn in range(n_iter):
losses = {} losses = {}
gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA) gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)
for batch in minibatch(gold_parses, size=3): for batch in minibatch(gold_parses, size=3):
@@ -139,5 +137,4 @@ def get_gold_parses(tokenizer, train_data):
if __name__ == '__main__': if __name__ == '__main__':
import plac
plac.call(main) plac.call(main)