mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Add document length cap for training
This commit is contained in:
parent
6771780d3f
commit
c2bbf076a4
|
@ -85,6 +85,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
|
||||
util.env_opt('batch_to', 16),
|
||||
util.env_opt('batch_compound', 1.001))
|
||||
max_doc_len = util.env_opt('max_doc_len', 5000)
|
||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||
n_train_words = corpus.count_train()
|
||||
|
||||
|
@ -108,6 +109,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
losses = {}
|
||||
for batch in minibatch(train_docs, size=batch_sizes):
|
||||
batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
|
||||
if not batch:
|
||||
continue
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer,
|
||||
drop=next(dropout_rates), losses=losses)
|
||||
|
|
Loading…
Reference in New Issue
Block a user