mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add meta.json option to cli.train and add relevant properties
Add accuracy scores to meta.json instead of accuracy.json and replace all relevant properties like lang, pipeline, spacy_version in existing meta.json. If not present, also add name and version placeholders to make it packagable.
This commit is contained in:
parent
d2d35b63b7
commit
edf7e4881d
|
@ -18,6 +18,7 @@ from ..gold import GoldParse, merge_sents
|
||||||
from ..gold import GoldCorpus, minibatch
|
from ..gold import GoldCorpus, minibatch
|
||||||
from ..util import prints
|
from ..util import prints
|
||||||
from .. import util
|
from .. import util
|
||||||
|
from .. import about
|
||||||
from .. import displacy
|
from .. import displacy
|
||||||
from ..compat import json_dumps
|
from ..compat import json_dumps
|
||||||
|
|
||||||
|
@ -35,10 +36,11 @@ from ..compat import json_dumps
|
||||||
no_parser=("Don't train parser", "flag", "P", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
no_entities=("Don't train NER", "flag", "N", bool),
|
no_entities=("Don't train NER", "flag", "N", bool),
|
||||||
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||||
|
meta_path=("Optional path to meta.json. All relevant properties will be overwritten.", "option", "m", Path)
|
||||||
)
|
)
|
||||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False,
|
use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False,
|
||||||
gold_preproc=False):
|
gold_preproc=False, meta_path=None):
|
||||||
"""
|
"""
|
||||||
Train a model. Expects data in spaCy's JSON format.
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
"""
|
"""
|
||||||
|
@ -47,13 +49,19 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
output_path = util.ensure_path(output_dir)
|
output_path = util.ensure_path(output_dir)
|
||||||
train_path = util.ensure_path(train_data)
|
train_path = util.ensure_path(train_data)
|
||||||
dev_path = util.ensure_path(dev_data)
|
dev_path = util.ensure_path(dev_data)
|
||||||
|
meta_path = util.ensure_path(meta_path)
|
||||||
if not output_path.exists():
|
if not output_path.exists():
|
||||||
output_path.mkdir()
|
output_path.mkdir()
|
||||||
if not train_path.exists():
|
if not train_path.exists():
|
||||||
prints(train_path, title="Training data not found", exits=1)
|
prints(train_path, title="Training data not found", exits=1)
|
||||||
if dev_path and not dev_path.exists():
|
if dev_path and not dev_path.exists():
|
||||||
prints(dev_path, title="Development data not found", exits=1)
|
prints(dev_path, title="Development data not found", exits=1)
|
||||||
|
if meta_path is not None and not meta_path.exists():
|
||||||
|
prints(meta_path, title="meta.json not found", exits=1)
|
||||||
|
meta = util.read_json(meta_path) if meta_path else {}
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
prints("Expected dict but got: {}".format(type(meta)),
|
||||||
|
title="Not a valid meta.json format", exits=1)
|
||||||
|
|
||||||
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
||||||
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
|
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
|
||||||
|
@ -105,9 +113,16 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
corpus.dev_docs(
|
corpus.dev_docs(
|
||||||
nlp,
|
nlp,
|
||||||
gold_preproc=gold_preproc))
|
gold_preproc=gold_preproc))
|
||||||
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
meta_loc = output_path / ('model%d' % i) / 'meta.json'
|
||||||
with acc_loc.open('w') as file_:
|
meta['accuracy'] = scorer.scores
|
||||||
file_.write(json_dumps(scorer.scores))
|
meta['lang'] = nlp.lang
|
||||||
|
meta['pipeline'] = pipeline
|
||||||
|
meta['spacy_version'] = '>=%s' % about.__version__
|
||||||
|
meta.setdefault('name', 'model%d' % i)
|
||||||
|
meta.setdefault('version', '0.0.0')
|
||||||
|
|
||||||
|
with meta_loc.open('w') as file_:
|
||||||
|
file_.write(json_dumps(meta))
|
||||||
util.set_env_log(True)
|
util.set_env_log(True)
|
||||||
print_progress(i, losses, scorer.scores)
|
print_progress(i, losses, scorer.scores)
|
||||||
finally:
|
finally:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user