mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Add document length cap for training
This commit is contained in:
parent
6771780d3f
commit
c2bbf076a4
|
@ -85,6 +85,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||||
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
|
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
|
||||||
util.env_opt('batch_to', 16),
|
util.env_opt('batch_to', 16),
|
||||||
util.env_opt('batch_compound', 1.001))
|
util.env_opt('batch_compound', 1.001))
|
||||||
|
max_doc_len = util.env_opt('max_doc_len', 5000)
|
||||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||||
n_train_words = corpus.count_train()
|
n_train_words = corpus.count_train()
|
||||||
|
|
||||||
|
@ -108,6 +109,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||||
losses = {}
|
losses = {}
|
||||||
for batch in minibatch(train_docs, size=batch_sizes):
|
for batch in minibatch(train_docs, size=batch_sizes):
|
||||||
|
batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
|
||||||
|
if not batch:
|
||||||
|
continue
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
nlp.update(docs, golds, sgd=optimizer,
|
nlp.update(docs, golds, sgd=optimizer,
|
||||||
drop=next(dropout_rates), losses=losses)
|
drop=next(dropout_rates), losses=losses)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user