mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Add meta.json option to cli.train and add relevant properties
Add accuracy scores to meta.json instead of accuracy.json and replace all relevant properties like lang, pipeline, spacy_version in existing meta.json. If not present, also add name and version placeholders to make it packagable.
This commit is contained in:
parent
d2d35b63b7
commit
edf7e4881d
|
@ -18,6 +18,7 @@ from ..gold import GoldParse, merge_sents
|
|||
from ..gold import GoldCorpus, minibatch
|
||||
from ..util import prints
|
||||
from .. import util
|
||||
from .. import about
|
||||
from .. import displacy
|
||||
from ..compat import json_dumps
|
||||
|
||||
|
@ -35,10 +36,11 @@ from ..compat import json_dumps
|
|||
no_parser=("Don't train parser", "flag", "P", bool),
|
||||
no_entities=("Don't train NER", "flag", "N", bool),
|
||||
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||
meta_path=("Optional path to meta.json. All relevant properties will be overwritten.", "option", "m", Path)
|
||||
)
|
||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||
use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False,
|
||||
gold_preproc=False):
|
||||
gold_preproc=False, meta_path=None):
|
||||
"""
|
||||
Train a model. Expects data in spaCy's JSON format.
|
||||
"""
|
||||
|
@ -47,13 +49,19 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
output_path = util.ensure_path(output_dir)
|
||||
train_path = util.ensure_path(train_data)
|
||||
dev_path = util.ensure_path(dev_data)
|
||||
meta_path = util.ensure_path(meta_path)
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
if not train_path.exists():
|
||||
prints(train_path, title="Training data not found", exits=1)
|
||||
if dev_path and not dev_path.exists():
|
||||
prints(dev_path, title="Development data not found", exits=1)
|
||||
|
||||
if meta_path is not None and not meta_path.exists():
|
||||
prints(meta_path, title="meta.json not found", exits=1)
|
||||
meta = util.read_json(meta_path) if meta_path else {}
|
||||
if not isinstance(meta, dict):
|
||||
prints("Expected dict but got: {}".format(type(meta)),
|
||||
title="Not a valid meta.json format", exits=1)
|
||||
|
||||
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
||||
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
|
||||
|
@ -105,9 +113,16 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
corpus.dev_docs(
|
||||
nlp,
|
||||
gold_preproc=gold_preproc))
|
||||
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
||||
with acc_loc.open('w') as file_:
|
||||
file_.write(json_dumps(scorer.scores))
|
||||
meta_loc = output_path / ('model%d' % i) / 'meta.json'
|
||||
meta['accuracy'] = scorer.scores
|
||||
meta['lang'] = nlp.lang
|
||||
meta['pipeline'] = pipeline
|
||||
meta['spacy_version'] = '>=%s' % about.__version__
|
||||
meta.setdefault('name', 'model%d' % i)
|
||||
meta.setdefault('version', '0.0.0')
|
||||
|
||||
with meta_loc.open('w') as file_:
|
||||
file_.write(json_dumps(meta))
|
||||
util.set_env_log(True)
|
||||
print_progress(i, losses, scorer.scores)
|
||||
finally:
|
||||
|
|
Loading…
Reference in New Issue
Block a user