Fix CLI for multitask objectives

This commit is contained in:
Matthew Honnibal 2018-02-18 10:59:11 +01:00
parent a34749b2bf
commit 86405e4ad1

View File

@ -30,8 +30,8 @@ from ..compat import json_dumps
no_tagger=("Don't train tagger", "flag", "T", bool), no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool), no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool), no_entities=("Don't train NER", "flag", "N", bool),
parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", ","), parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", str),
entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", ","), entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", str),
gold_preproc=("Use gold preprocessing", "flag", "G", bool), gold_preproc=("Use gold preprocessing", "flag", "G", bool),
version=("Model version", "option", "V", str), version=("Model version", "option", "V", str),
meta_path=("Optional path to meta.json. All relevant properties will be " meta_path=("Optional path to meta.json. All relevant properties will be "
@ -105,10 +105,12 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
lex.is_oov = False lex.is_oov = False
for name in pipeline: for name in pipeline:
nlp.add_pipe(nlp.create_pipe(name), name=name) nlp.add_pipe(nlp.create_pipe(name), name=name)
for objective in parser_multitasks.split(','): if parser_multitasks:
nlp.parser.add_multitask_objective(objective) for objective in parser_multitasks.split(','):
for objective in entity_multitasks.split(','): nlp.parser.add_multitask_objective(objective)
nlp.entity.add_multitask_objective(objective) if entity_multitasks:
for objective in entity_multitasks.split(','):
nlp.entity.add_multitask_objective(objective)
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu) optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
nlp._optimizer = None nlp._optimizer = None