From bc2294d7f18977028b349058a9ac6d88313e5e2e Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 22 May 2017 04:47:14 -0500 Subject: [PATCH] Add support for fiddly hyper-parameters to train func --- spacy/cli/train.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 1b847301d..a25a7f252 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -7,6 +7,7 @@ import cytoolz from pathlib import Path import dill import tqdm +from thinc.neural.optimizers import linear_decay from ..tokens.doc import Doc from ..scorer import Scorer @@ -40,24 +41,35 @@ def train(lang_id, output_dir, train_data, dev_data, n_iter, n_sents, corpus = GoldCorpus(train_path, dev_path) dropout = util.env_opt('dropout', 0.0) + dropout_decay = util.env_opt('dropout_decay', 0.0) optimizer = nlp.begin_training(lambda: corpus.train_tuples, use_gpu=use_gpu) n_train_docs = corpus.count_train() + batch_size = float(util.env_opt('min_batch_size', 4)) + max_batch_size = util.env_opt('max_batch_size', 64) + batch_accel = util.env_opt('batch_accel', 1.001) print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %") for i in range(n_iter): with tqdm.tqdm(total=n_train_docs) as pbar: train_docs = corpus.train_docs(nlp, shuffle=i, projectivize=True) - for batch in cytoolz.partition_all(20, train_docs): + idx = 0 + while idx < n_train_docs: + batch = list(cytoolz.take(int(batch_size), train_docs)) + if not batch: + break docs, golds = zip(*batch) - docs = list(docs) - golds = list(golds) nlp.update(docs, golds, drop=dropout, sgd=optimizer) pbar.update(len(docs)) + idx += len(docs) + batch_size *= batch_accel + batch_size = min(int(batch_size), max_batch_size) + dropout = linear_decay(dropout, dropout_decay, i*n_train_docs+idx) with nlp.use_params(optimizer.averages): scorer = nlp.evaluate(corpus.dev_docs(nlp)) print_progress(i, {}, scorer.scores) with (output_path / 'model.bin').open('wb') as file_: - dill.dump(nlp, file_, -1) + with nlp.use_params(optimizer.averages): + dill.dump(nlp, file_, -1) def _render_parses(i, to_render):