Add gold_preproc flag to cli/train

This commit is contained in:
Matthew Honnibal 2017-08-20 11:07:00 -05:00
parent 42fa84075f
commit 84bb543e4d

View File

@ -32,10 +32,12 @@ from ..compat import json_dumps
resume=("Whether to resume training", "flag", "R", bool),
no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool)
no_entities=("Don't train NER", "flag", "N", bool),
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
)
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False):
use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False,
gold_preproc=False):
"""
Train a model. Expects data in spaCy's JSON format.
"""
@ -86,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
i += 20
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
train_docs = corpus.train_docs(nlp, projectivize=True,
gold_preproc=False, max_length=0)
gold_preproc=gold_preproc, max_length=0)
losses = {}
for batch in minibatch(train_docs, size=batch_sizes):
docs, golds = zip(*batch)
@ -104,7 +106,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
scorer = nlp_loaded.evaluate(
corpus.dev_docs(
nlp_loaded,
gold_preproc=False))
gold_preproc=gold_preproc))
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
with acc_loc.open('w') as file_:
file_.write(json_dumps(scorer.scores))