Clean up spacy.cli.train

Matthew Honnibal 2017-05-25 16:16:30 -05:00
parent b9cea9cd93
commit 702fe74a4d

spacy/cli/train.py

@@ -14,7 +14,7 @@ from timeit import default_timer as timer
 from ..tokens.doc import Doc
 from ..scorer import Scorer
 from ..gold import GoldParse, merge_sents
-from ..gold import GoldCorpus
+from ..gold import GoldCorpus, minibatch
 from ..util import prints
 from .. import util
 from .. import displacy
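The import hunk above pulls in minibatch, which replaces the manual cytoolz.take loop in the training hunk below. Its useful property is that size may be either a plain int or a generator that yields a new size per batch, which is what lets the batch size grow over training. A minimal standalone sketch of what such a helper can look like (illustrative only, not the exact spaCy source):

from itertools import islice, repeat

def minibatch(items, size=8):
    # `size` may be an int or a generator of ints, so the effective
    # batch size can change from one batch to the next. A generator
    # schedule is assumed to be infinite, like the one used below.
    sizes = repeat(size) if isinstance(size, int) else size
    items = iter(items)
    while True:
        batch = list(islice(items, int(next(sizes))))
        if not batch:
            return
        yield batch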
@@ -53,44 +53,38 @@ def train(_, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     if no_parser and 'dependencies' in pipeline: pipeline.remove('dependencies')
     if no_entities and 'entities' in pipeline: pipeline.remove('entities')
+    # Take dropout and batch size as generators of values -- dropout
+    # starts high and decays sharply, to force the optimizer to explore.
+    # Batch size starts at 1 and grows, so that we make updates quickly
+    # at the beginning of training.
+    dropout_rates = util.decaying(util.env_opt('dropout_from', 0.0),
+                                  util.env_opt('dropout_to', 0.0),
+                                  util.env_opt('dropout_decay', 0.0))
+    batch_sizes = util.compounding(util.env_opt('batch_from', 1),
+                                   util.env_opt('batch_to', 64),
+                                   util.env_opt('batch_compound', 1.001))
     nlp = lang_class(pipeline=pipeline)
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
-    n_train_docs = corpus.count_train()
-    dropout = util.env_opt('dropout', 0.0)
-    dropout_decay = util.env_opt('dropout_decay', 0.0)
-    orig_dropout = dropout
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, use_gpu=use_gpu)
-    n_train_docs = corpus.count_train()
-    batch_size = float(util.env_opt('min_batch_size', 4))
-    max_batch_size = util.env_opt('max_batch_size', 64)
-    batch_accel = util.env_opt('batch_accel', 1.001)
     print("Itn.\tDep. Loss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
     for i in range(n_iter):
-        with tqdm.tqdm(total=n_train_docs) as pbar:
-            train_docs = corpus.train_docs(nlp, shuffle=i, projectivize=True,
-                                           gold_preproc=False)
+        with tqdm.tqdm(total=corpus.count_train()) as pbar:
+            train_docs = corpus.train_docs(nlp, projectivize=True,
+                                           gold_preproc=False, shuffle=i)
             losses = {}
-            idx = 0
-            while idx < n_train_docs:
-                batch = list(cytoolz.take(int(batch_size), train_docs))
-                if not batch:
-                    break
+            for batch in minibatch(train_docs, size=batch_sizes):
                 docs, golds = zip(*batch)
-                nlp.update(docs, golds, drop=dropout, sgd=optimizer, losses=losses)
+                nlp.update(docs, golds, sgd=optimizer,
+                           drop=next(dropout_rates), losses=losses)
                 pbar.update(len(docs))
-                idx += len(docs)
-                batch_size *= batch_accel
-                batch_size = min(batch_size, max_batch_size)
-                dropout = linear_decay(orig_dropout, dropout_decay, i*n_train_docs+idx)
         with nlp.use_params(optimizer.averages):
-            start = timer()
             scorer = nlp.evaluate(corpus.dev_docs(nlp, gold_preproc=False))
-            end = timer()
-            n_words = scorer.tokens.tp + scorer.tokens.fn
-            assert n_words != 0
-            wps = n_words / (end-start)
-            print_progress(i, losses, scorer.scores, wps=wps)
+            print_progress(i, losses, scorer.scores)
     with (output_path / 'model.bin').open('wb') as file_:
         with nlp.use_params(optimizer.averages):
             dill.dump(nlp, file_, -1)
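After this change the loop draws its hyperparameters from infinite generators: util.compounding grows the batch size multiplicatively from batch_from toward batch_to, and util.decaying shrinks the dropout rate by a fixed step per update, as the new comment describes. Sketches of what such schedule generators can look like, assuming multiplicative growth capped at the target and linear decay floored at it (illustrative, not the exact spacy.util source):

def compounding(start, stop, compound):
    # start, start*compound, start*compound**2, ... capped at `stop`.
    curr = float(start)
    while True:
        yield min(curr, stop)
        curr *= compound

def decaying(start, stop, decay):
    # start, start - decay, start - 2*decay, ... floored at `stop`.
    curr = float(start)
    while True:
        yield max(curr, stop)
        curr -= decay

batch_sizes = compounding(1, 64, 1.001)    # one size per minibatch
dropout_rates = decaying(0.4, 0.2, 1e-4)   # hypothetical values; the diff's
                                           # env_opt defaults are all 0.0

Note that with the defaults shown in the diff, dropout stays at 0.0 unless the dropout_from/dropout_to/dropout_decay options are overridden via util.env_opt.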
@@ -118,7 +112,6 @@ def print_progress(itn, losses, dev_scores, wps=0.0):
     tpl = '\t'.join((
         '{:d}',
         '{dep_loss:.3f}',
-        '{tag_loss:.3f}',
         '{uas:.3f}',
         '{ents_p:.3f}',
         '{ents_r:.3f}',