mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Make dev data optional
This commit is contained in:
parent
0fc56e2544
commit
53cf2f1c0e
|
@ -16,7 +16,7 @@ from .. import util
|
||||||
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner):
|
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner):
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
train_path = Path(train_data)
|
train_path = Path(train_data)
|
||||||
dev_path = Path(dev_data)
|
dev_path = Path(dev_data) if dev_data else None
|
||||||
check_dirs(output_path, data_path, dev_path)
|
check_dirs(output_path, data_path, dev_path)
|
||||||
|
|
||||||
lang = util.get_lang_class(language)
|
lang = util.get_lang_class(language)
|
||||||
|
@ -26,12 +26,13 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
||||||
parser_cfg['features'] = lang.Defaults.parser_features
|
parser_cfg['features'] = lang.Defaults.parser_features
|
||||||
entity_cfg['features'] = lang.Defaults.entity_features
|
entity_cfg['features'] = lang.Defaults.entity_features
|
||||||
gold_train = list(read_gold_json(train_path))
|
gold_train = list(read_gold_json(train_path))
|
||||||
gold_dev = list(read_gold_json(dev_path))
|
gold_dev = list(read_gold_json(dev_path)) if dev_path else None
|
||||||
|
|
||||||
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
||||||
entity_cfg, n_iter)
|
entity_cfg, n_iter)
|
||||||
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
|
if gold_dev:
|
||||||
print_results(scorer)
|
scorer = evaluate(lang, gold_dev, output_path)
|
||||||
|
print_results(scorer)
|
||||||
|
|
||||||
|
|
||||||
def train_config(config):
|
def train_config(config):
|
||||||
|
@ -54,7 +55,7 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
|
||||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||||
for doc, gold in epoch:
|
for doc, gold in epoch:
|
||||||
trainer.update(doc, gold)
|
trainer.update(doc, gold)
|
||||||
dev_scores = trainer.evaluate(dev_data)
|
dev_scores = trainer.evaluate(dev_data) if dev_data else []
|
||||||
print_progress(itn, trainer.nlp.parser.model.nr_weight,
|
print_progress(itn, trainer.nlp.parser.model.nr_weight,
|
||||||
trainer.nlp.parser.model.nr_active_feat,
|
trainer.nlp.parser.model.nr_active_feat,
|
||||||
**dev_scores.scores)
|
**dev_scores.scores)
|
||||||
|
@ -82,8 +83,10 @@ def evaluate(Language, gold_tuples, output_path):
|
||||||
def check_dirs(input_path, train_path, dev_path):
|
def check_dirs(input_path, train_path, dev_path):
|
||||||
if not output_path.exists():
|
if not output_path.exists():
|
||||||
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
||||||
if not train_path.exists() and train_path.is_file():
|
if not train_path.exists() or not train_path.is_file():
|
||||||
util.sys_exit(train_path.as_posix(), title="Training data not found")
|
util.sys_exit(train_path.as_posix(), title="Training data not found")
|
||||||
|
if dev_path and not dev_path.exists():
|
||||||
|
util.sys_exit(dev_path.as_posix(), title="Development data not found")
|
||||||
|
|
||||||
|
|
||||||
def print_progress(itn, nr_weight, nr_active_feat, **scores):
|
def print_progress(itn, nr_weight, nr_active_feat, **scores):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user