mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix CLI for multitask objectives
This commit is contained in:
parent
a34749b2bf
commit
86405e4ad1
|
@ -30,8 +30,8 @@ from ..compat import json_dumps
|
||||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
no_parser=("Don't train parser", "flag", "P", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
no_entities=("Don't train NER", "flag", "N", bool),
|
no_entities=("Don't train NER", "flag", "N", bool),
|
||||||
parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", ","),
|
parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", str),
|
||||||
entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", ","),
|
entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", str),
|
||||||
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||||
version=("Model version", "option", "V", str),
|
version=("Model version", "option", "V", str),
|
||||||
meta_path=("Optional path to meta.json. All relevant properties will be "
|
meta_path=("Optional path to meta.json. All relevant properties will be "
|
||||||
|
@ -105,10 +105,12 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||||
lex.is_oov = False
|
lex.is_oov = False
|
||||||
for name in pipeline:
|
for name in pipeline:
|
||||||
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
||||||
for objective in parser_multitasks.split(','):
|
if parser_multitasks:
|
||||||
nlp.parser.add_multitask_objective(objective)
|
for objective in parser_multitasks.split(','):
|
||||||
for objective in entity_multitasks.split(','):
|
nlp.parser.add_multitask_objective(objective)
|
||||||
nlp.entity.add_multitask_objective(objective)
|
if entity_multitasks:
|
||||||
|
for objective in entity_multitasks.split(','):
|
||||||
|
nlp.entity.add_multitask_objective(objective)
|
||||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
||||||
nlp._optimizer = None
|
nlp._optimizer = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user