Add multitask objectives options to train CLI

This commit is contained in:
Matthew Honnibal 2018-02-17 22:03:54 +01:00
parent 8f06903e09
commit a34749b2bf

View File

@ -30,11 +30,14 @@ from ..compat import json_dumps
no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool),
parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", ","),
entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", ","),
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
version=("Model version", "option", "V", str),
meta_path=("Optional path to meta.json. All relevant properties will be "
"overwritten.", "option", "m", Path))
def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
parser_multitasks='', entity_multitasks='',
use_gpu=-1, vectors=None, no_tagger=False,
no_parser=False, no_entities=False, gold_preproc=False,
version="0.0.0", meta_path=None):
@ -102,6 +105,10 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
lex.is_oov = False
for name in pipeline:
nlp.add_pipe(nlp.create_pipe(name), name=name)
for objective in parser_multitasks.split(','):
nlp.parser.add_multitask_objective(objective)
for objective in entity_multitasks.split(','):
nlp.entity.add_multitask_objective(objective)
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
nlp._optimizer = None