mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-04 05:34:10 +03:00
Set accelerating batch size in CONLL train script
This commit is contained in:
parent
661873ee4c
commit
6a27a4f77c
|
@@ -218,13 +218,18 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
||||||
n_train_words = sum(len(doc) for doc in docs)
|
n_train_words = sum(len(doc) for doc in docs)
|
||||||
print(n_train_words)
|
print(n_train_words)
|
||||||
print("Begin training")
|
print("Begin training")
|
||||||
|
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||||
|
# at the beginning of training.
|
||||||
|
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
|
||||||
|
spacy.util.env_opt('batch_to', 8),
|
||||||
|
spacy.util.env_opt('batch_compound', 1.001))
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
with open(text_train_loc) as file_:
|
with open(text_train_loc) as file_:
|
||||||
docs = get_docs(nlp, file_.read())
|
docs = get_docs(nlp, file_.read())
|
||||||
docs = docs[:len(golds)]
|
docs = docs[:len(golds)]
|
||||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||||
losses = {}
|
losses = {}
|
||||||
for batch in minibatch(list(zip(docs, golds)), size=1):
|
for batch in minibatch(list(zip(docs, golds)), size=batch_sizes):
|
||||||
if not batch:
|
if not batch:
|
||||||
continue
|
continue
|
||||||
batch_docs, batch_gold = zip(*batch)
|
batch_docs, batch_gold = zip(*batch)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user