Add meta.json option to cli.train and add relevant properties

Add accuracy scores to meta.json instead of accuracy.json and replace
all relevant properties like lang, pipeline, spacy_version in existing
meta.json. If not present, also add name and version placeholders to
make it packagable.
This commit is contained in:
ines 2017-09-25 19:00:47 +02:00
parent d2d35b63b7
commit edf7e4881d

View File

@ -18,6 +18,7 @@ from ..gold import GoldParse, merge_sents
from ..gold import GoldCorpus, minibatch
from ..util import prints
from .. import util
from .. import about
from .. import displacy
from ..compat import json_dumps
@ -35,10 +36,11 @@ from ..compat import json_dumps
no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool),
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
meta_path=("Optional path to meta.json. All relevant properties will be overwritten.", "option", "m", Path)
)
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False,
gold_preproc=False):
gold_preproc=False, meta_path=None):
"""
Train a model. Expects data in spaCy's JSON format.
"""
@ -47,13 +49,19 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
output_path = util.ensure_path(output_dir)
train_path = util.ensure_path(train_data)
dev_path = util.ensure_path(dev_data)
meta_path = util.ensure_path(meta_path)
if not output_path.exists():
output_path.mkdir()
if not train_path.exists():
prints(train_path, title="Training data not found", exits=1)
if dev_path and not dev_path.exists():
prints(dev_path, title="Development data not found", exits=1)
if meta_path is not None and not meta_path.exists():
prints(meta_path, title="meta.json not found", exits=1)
meta = util.read_json(meta_path) if meta_path else {}
if not isinstance(meta, dict):
prints("Expected dict but got: {}".format(type(meta)),
title="Not a valid meta.json format", exits=1)
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
@ -105,9 +113,16 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
corpus.dev_docs(
nlp,
gold_preproc=gold_preproc))
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
with acc_loc.open('w') as file_:
file_.write(json_dumps(scorer.scores))
meta_loc = output_path / ('model%d' % i) / 'meta.json'
meta['accuracy'] = scorer.scores
meta['lang'] = nlp.lang
meta['pipeline'] = pipeline
meta['spacy_version'] = '>=%s' % about.__version__
meta.setdefault('name', 'model%d' % i)
meta.setdefault('version', '0.0.0')
with meta_loc.open('w') as file_:
file_.write(json_dumps(meta))
util.set_env_log(True)
print_progress(i, losses, scorer.scores)
finally: