mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Update train CLI
This commit is contained in:
parent
5eac089fbe
commit
dec5571bf3
|
@ -81,16 +81,17 @@ class CLI(object):
|
||||||
output_dir=("output directory", "positional", None, str),
|
output_dir=("output directory", "positional", None, str),
|
||||||
train_data=("training data", "positional", None, str),
|
train_data=("training data", "positional", None, str),
|
||||||
dev_data=("development data", "positional", None, str),
|
dev_data=("development data", "positional", None, str),
|
||||||
n_iter=("number of iterations", "flag", "n", int),
|
n_iter=("number of iterations", "option", "n", int),
|
||||||
tagger=("train tagger", "flag", "t", bool),
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
parser=("train parser", "flag", "p", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
ner=("train NER", "flag", "n", bool)
|
no_ner=("Don't train NER", "flag", "N", bool)
|
||||||
)
|
)
|
||||||
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True,
|
def train(self, lang, output_dir, train_data, dev_data, n_iter=15,
|
||||||
parser=True, ner=True):
|
no_tagger=False, no_parser=False, no_ner=False):
|
||||||
"""Train a model."""
|
"""Train a model."""
|
||||||
|
|
||||||
cli_train(output_dir, train_data, dev_data, tagger, parser, ner)
|
cli_train(lang, output_dir, train_data, dev_data, n_iter,
|
||||||
|
not no_tagger, not no_parser, not no_ner)
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
|
|
|
@ -17,20 +17,28 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
train_path = Path(train_data)
|
train_path = Path(train_data)
|
||||||
dev_path = Path(dev_data)
|
dev_path = Path(dev_data)
|
||||||
check_dirs(output_path, data_path, dev_path)
|
check_dirs(output_path, train_path, dev_path)
|
||||||
|
|
||||||
lang = util.get_lang_class(language)
|
lang = util.get_lang_class(language)
|
||||||
parser_cfg = dict(locals())
|
parser_cfg = {
|
||||||
tagger_cfg = dict(locals())
|
'pseudoprojective': True,
|
||||||
entity_cfg = dict(locals())
|
'n_iter': n_iter,
|
||||||
parser_cfg['features'] = lang.Defaults.parser_features
|
'lang': language,
|
||||||
entity_cfg['features'] = lang.Defaults.entity_features
|
'features': lang.Defaults.parser_features}
|
||||||
|
entity_cfg = {
|
||||||
|
'n_iter': n_iter,
|
||||||
|
'lang': language,
|
||||||
|
'features': lang.Defaults.entity_features}
|
||||||
|
tagger_cfg = {
|
||||||
|
'n_iter': n_iter,
|
||||||
|
'lang': language,
|
||||||
|
'features': lang.Defaults.tagger_features}
|
||||||
gold_train = list(read_gold_json(train_path))
|
gold_train = list(read_gold_json(train_path))
|
||||||
gold_dev = list(read_gold_json(dev_path))
|
gold_dev = list(read_gold_json(dev_path))
|
||||||
|
|
||||||
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
||||||
entity_cfg, n_iter)
|
entity_cfg, n_iter)
|
||||||
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
|
scorer = evaluate(lang, list(read_gold_json(dev_path)), output_path)
|
||||||
print_results(scorer)
|
print_results(scorer)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,7 +87,7 @@ def evaluate(Language, gold_tuples, output_path):
|
||||||
return scorer
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
def check_dirs(input_path, train_path, dev_path):
|
def check_dirs(output_path, train_path, dev_path):
|
||||||
if not output_path.exists():
|
if not output_path.exists():
|
||||||
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
||||||
if not train_path.exists() and train_path.is_file():
|
if not train_path.exists() and train_path.is_file():
|
||||||
|
@ -92,7 +100,12 @@ def print_progress(itn, nr_weight, nr_active_feat, **scores):
|
||||||
|
|
||||||
|
|
||||||
def print_results(scorer):
|
def print_results(scorer):
|
||||||
results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas,
|
results = {
|
||||||
'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r,
|
'TOK': '%.2f' % scorer.token_acc,
|
||||||
'NER F': scorer.ents_f}
|
'POS': '%.2f' % scorer.tags_acc,
|
||||||
|
'UAS': '%.2f' % scorer.uas,
|
||||||
|
'LAS': '%.2f' % scorer.las,
|
||||||
|
'NER P': '%.2f' % scorer.ents_p,
|
||||||
|
'NER R': '%.2f' % scorer.ents_r,
|
||||||
|
'NER F': '%.2f' % scorer.ents_f}
|
||||||
util.print_table(results, title="Results")
|
util.print_table(results, title="Results")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user