From c3b681e5fbe157ea70167da1e67c740e8339af6f Mon Sep 17 00:00:00 2001
From: ines
Date: Thu, 26 Oct 2017 16:11:05 +0200
Subject: [PATCH] Use plac annotations for arguments and add n_iter

---
 examples/training/train_new_entity_type.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py
index ea6c08763..69ee20e04 100644
--- a/examples/training/train_new_entity_type.py
+++ b/examples/training/train_new_entity_type.py
@@ -29,6 +29,7 @@ Last updated for: spaCy 2.0.0a18
 """
 from __future__ import unicode_literals, print_function
 
+import plac
 import random
 from pathlib import Path
 
@@ -58,16 +59,13 @@ TRAIN_DATA = [
 ]
 
 
-def main(model=None, new_model_name='animal', output_dir=None):
-    """Set up the pipeline and entity recognizer, and train the new entity.
-
-    model (unicode): Model name to start off with. If None, a blank English
-        Language class is created.
-    new_model_name (unicode): Name of new model to create. Will be added to the
-        model meta and prefixed by the language code, e.g. 'en_animal'.
-    output_dir (unicode / Path): Optional output directory. If None, no model
-        will be saved.
-    """
+@plac.annotations(
+    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
+    new_model_name=("New model name for model meta.", "option", "nm", str),
+    output_dir=("Optional output directory", "option", "o", Path),
+    n_iter=("Number of training iterations", "option", "n", int))
+def main(model=None, new_model_name='animal', output_dir=None, n_iter=50):
+    """Set up the pipeline and entity recognizer, and train the new entity."""
     if model is not None:
         nlp = spacy.load(model)  # load existing spaCy model
         print("Loaded model '%s'" % model)
@@ -91,7 +89,7 @@ def main(model=None, new_model_name='animal', output_dir=None):
     with nlp.disable_pipes(*other_pipes) as disabled:  # only train NER
         random.seed(0)
         optimizer = nlp.begin_training(lambda: [])
-        for itn in range(50):
+        for itn in range(n_iter):
             losses = {}
             gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)
             for batch in minibatch(gold_parses, size=3):
@@ -139,5 +137,4 @@ def get_gold_parses(tokenizer, train_data):
 
 
 if __name__ == '__main__':
-    import plac
     plac.call(main)
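
Note (not part of the patch): a minimal, self-contained sketch of the plac pattern the change adopts. Each keyword passed to @plac.annotations is a (help text, kind, abbreviation, type) tuple, and plac.call() turns the decorated function's signature into a command-line interface, so n_iter becomes the -n option with a default of 50. The demo() function, the demo.py filename and the model name in the usage comment are illustrative assumptions, not taken from the patched example.

    # Hypothetical standalone demo of the plac.annotations / plac.call pattern
    # used in the patch above; demo() is illustrative and not part of spaCy.
    from __future__ import unicode_literals, print_function
    from pathlib import Path

    import plac


    @plac.annotations(
        model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
        output_dir=("Optional output directory", "option", "o", Path),
        n_iter=("Number of training iterations", "option", "n", int))
    def demo(model=None, output_dir=None, n_iter=50):
        # plac fills these keyword arguments from the -m, -o and -n flags
        print(model, output_dir, n_iter)


    if __name__ == '__main__':
        # e.g. `python demo.py -m en_core_web_sm -o /tmp/animal-model -n 30`
        plac.call(demo)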