mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Add multitask objectives options to train CLI
This commit is contained in:
parent
8f06903e09
commit
a34749b2bf
|
@ -30,11 +30,14 @@ from ..compat import json_dumps
|
|||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||
no_parser=("Don't train parser", "flag", "P", bool),
|
||||
no_entities=("Don't train NER", "flag", "N", bool),
|
||||
parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", ","),
|
||||
entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", ","),
|
||||
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||
version=("Model version", "option", "V", str),
|
||||
meta_path=("Optional path to meta.json. All relevant properties will be "
|
||||
"overwritten.", "option", "m", Path))
|
||||
def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||
parser_multitasks='', entity_multitasks='',
|
||||
use_gpu=-1, vectors=None, no_tagger=False,
|
||||
no_parser=False, no_entities=False, gold_preproc=False,
|
||||
version="0.0.0", meta_path=None):
|
||||
|
@ -102,6 +105,10 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
lex.is_oov = False
|
||||
for name in pipeline:
|
||||
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
||||
for objective in parser_multitasks.split(','):
|
||||
nlp.parser.add_multitask_objective(objective)
|
||||
for objective in entity_multitasks.split(','):
|
||||
nlp.entity.add_multitask_objective(objective)
|
||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
||||
nlp._optimizer = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user