mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Update train CLI
This commit is contained in:
parent
5eac089fbe
commit
dec5571bf3
|
@ -81,16 +81,17 @@ class CLI(object):
|
|||
output_dir=("output directory", "positional", None, str),
|
||||
train_data=("training data", "positional", None, str),
|
||||
dev_data=("development data", "positional", None, str),
|
||||
n_iter=("number of iterations", "flag", "n", int),
|
||||
tagger=("train tagger", "flag", "t", bool),
|
||||
parser=("train parser", "flag", "p", bool),
|
||||
ner=("train NER", "flag", "n", bool)
|
||||
n_iter=("number of iterations", "option", "n", int),
|
||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||
no_parser=("Don't train parser", "flag", "P", bool),
|
||||
no_ner=("Don't train NER", "flag", "N", bool)
|
||||
)
|
||||
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True,
|
||||
parser=True, ner=True):
|
||||
def train(self, lang, output_dir, train_data, dev_data, n_iter=15,
|
||||
no_tagger=False, no_parser=False, no_ner=False):
|
||||
"""Train a model."""
|
||||
|
||||
cli_train(output_dir, train_data, dev_data, tagger, parser, ner)
|
||||
cli_train(lang, output_dir, train_data, dev_data, n_iter,
|
||||
not no_tagger, not no_parser, not no_ner)
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
|
|
@ -17,20 +17,28 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
|||
output_path = Path(output_dir)
|
||||
train_path = Path(train_data)
|
||||
dev_path = Path(dev_data)
|
||||
check_dirs(output_path, data_path, dev_path)
|
||||
check_dirs(output_path, train_path, dev_path)
|
||||
|
||||
lang = util.get_lang_class(language)
|
||||
parser_cfg = dict(locals())
|
||||
tagger_cfg = dict(locals())
|
||||
entity_cfg = dict(locals())
|
||||
parser_cfg['features'] = lang.Defaults.parser_features
|
||||
entity_cfg['features'] = lang.Defaults.entity_features
|
||||
parser_cfg = {
|
||||
'pseudoprojective': True,
|
||||
'n_iter': n_iter,
|
||||
'lang': language,
|
||||
'features': lang.Defaults.parser_features}
|
||||
entity_cfg = {
|
||||
'n_iter': n_iter,
|
||||
'lang': language,
|
||||
'features': lang.Defaults.entity_features}
|
||||
tagger_cfg = {
|
||||
'n_iter': n_iter,
|
||||
'lang': language,
|
||||
'features': lang.Defaults.tagger_features}
|
||||
gold_train = list(read_gold_json(train_path))
|
||||
gold_dev = list(read_gold_json(dev_path))
|
||||
|
||||
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
||||
entity_cfg, n_iter)
|
||||
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
|
||||
scorer = evaluate(lang, list(read_gold_json(dev_path)), output_path)
|
||||
print_results(scorer)
|
||||
|
||||
|
||||
|
@ -79,7 +87,7 @@ def evaluate(Language, gold_tuples, output_path):
|
|||
return scorer
|
||||
|
||||
|
||||
def check_dirs(input_path, train_path, dev_path):
|
||||
def check_dirs(output_path, train_path, dev_path):
|
||||
if not output_path.exists():
|
||||
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
||||
if not train_path.exists() and train_path.is_file():
|
||||
|
@ -92,7 +100,12 @@ def print_progress(itn, nr_weight, nr_active_feat, **scores):
|
|||
|
||||
|
||||
def print_results(scorer):
|
||||
results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas,
|
||||
'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r,
|
||||
'NER F': scorer.ents_f}
|
||||
results = {
|
||||
'TOK': '%.2f' % scorer.token_acc,
|
||||
'POS': '%.2f' % scorer.tags_acc,
|
||||
'UAS': '%.2f' % scorer.uas,
|
||||
'LAS': '%.2f' % scorer.las,
|
||||
'NER P': '%.2f' % scorer.ents_p,
|
||||
'NER R': '%.2f' % scorer.ents_r,
|
||||
'NER F': '%.2f' % scorer.ents_f}
|
||||
util.print_table(results, title="Results")
|
||||
|
|
Loading…
Reference in New Issue
Block a user