mirror of https://github.com/explosion/spaCy.git
Use increasing batch sizes in ud-train
commit d7ce6527fb
parent 7b755414eb
@@ -13,6 +13,7 @@ import spacy
 import spacy.util
 from ..tokens import Token, Doc
 from ..gold import GoldParse
+from ..util import compounding
 from ..syntax.nonproj import projectivize
 from ..matcher import Matcher
 from collections import defaultdict, Counter
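The patch pulls in spaCy's compounding helper. For reference, a minimal sketch of what that utility does, based on the spacy.util implementation of the era (the clipping detail is my reading, not part of this commit):

def compounding(start, stop, compound):
    # Yield an infinite series of compounding values: each value is the
    # previous one multiplied by `compound`, clipped so the series never
    # passes `stop` from either direction.
    def clip(value):
        return max(value, stop) if (start > stop) else min(value, stop)
    curr = float(start)
    while True:
        yield clip(curr)
        curr *= compound

sizes = compounding(1.0, 10.0, 1.5)
assert next(sizes) == 1.0
assert next(sizes) == 1.0 * 1.5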
@@ -36,7 +37,7 @@ lang.ja.Japanese.Defaults.use_janome = False
 random.seed(0)
 numpy.random.seed(0)
 
-def minibatch_by_words(items, size=5000):
+def minibatch_by_words(items, size):
     random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
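The hunk only shows the head of the reworked minibatch_by_words; the point of the change is that size may now be a generator of growing budgets rather than a fixed int. A plausible completion of the function under that assumption (the batching loop below is an illustration, not the verbatim file):

import itertools
import random

def minibatch_by_words(items, size):
    # `size` is either a fixed int or an iterator of word budgets,
    # e.g. the compounding() generator created in main().
    random.shuffle(items)
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = size
    items = iter(items)
    while True:
        batch_size = next(size_)  # word budget for this batch
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:
                    yield batch
                return
            batch_size -= len(doc)
            batch.append((doc, gold))
        if batch:
            yield batch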
@@ -368,9 +369,10 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
 
     optimizer = initialize_pipeline(nlp, docs, golds, config)
 
+    batch_sizes = compounding(config.batch_size //10, config.batch_size, 1.001)
     for i in range(config.nr_epoch):
         docs = [nlp.make_doc(doc.text) for doc in docs]
-        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
+        batches = minibatch_by_words(list(zip(docs, golds)), size=batch_sizes)
         losses = {}
         n_train_words = sum(len(doc) for doc in docs)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
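Net effect: training now starts with batches of roughly config.batch_size // 10 words and lets the budget compound by 0.1% per batch up to the configured ceiling, instead of using one fixed size for the whole run. A quick check of the schedule, assuming config.batch_size is 5000 (the old hard-coded default removed above):

from spacy.util import compounding  # the helper imported by this patch

sizes = compounding(5000 // 10, 5000, 1.001)
print([int(next(sizes)) for _ in range(3)])  # -> [500, 500, 501], growing toward 5000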