Offer option of padding-sensitive batching

Matthew Honnibal 2020-07-09 14:50:20 +02:00
parent 3a7f275c02
commit 77af0a6bb4

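The change reads the batching strategy from the training config. As a rough illustration (the key names appear in the diff below; the concrete values are invented for the example), the relevant options look like this:

# Illustrative only: the key names come from the diff, the values are made up.
cfg = {
    "batch_by": "padded",        # "padded", "words", or the default "sequences"
    "batch_size": 2000,          # budget interpreted by the chosen strategy
    "discard_oversize": False,   # drop examples that exceed the budget on their own
}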

@@ -303,11 +303,19 @@ def create_train_batches(nlp, corpus, cfg):
     )
     epoch = 0
+    batch_strategy = cfg.get("batch_by", "sequences")
     while True:
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
         epoch += 1
-        if cfg.get("batch_by_words", True):
+        if batch_strategy == "padded":
+            batches = util.minibatch_by_padded_size(
+                train_examples,
+                size=cfg["batch_size"],
+                buffer=256,
+                discard_oversize=cfg["discard_oversize"],
+            )
+        elif batch_strategy == "words":
             batches = util.minibatch_by_words(
                 train_examples,
                 size=cfg["batch_size"],
@@ -318,7 +326,7 @@ def create_train_batches(nlp, corpus, cfg):
                 train_examples,
                 size=cfg["batch_size"],
             )
         # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
         try:
             first = next(batches)
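To make the new "padded" strategy concrete, here is a minimal sketch of padding-sensitive batching. It is not spaCy's util.minibatch_by_padded_size, only an illustration of the idea: group examples so that the padded cost (longest sequence in the batch times the number of sequences) stays within the size budget, optionally discarding examples that could never fit on their own.

import itertools

def minibatch_by_padded_size_sketch(examples, size, buffer=256, discard_oversize=False):
    """Sketch only: yield batches whose padded cost (longest example in the
    batch * number of examples) stays within `size`. Assumes len(example)
    gives the sequence length; the real helper may differ in its details."""
    examples = iter(examples)
    while True:
        # Take a window of examples and sort it by length so that
        # similarly sized sequences land in the same batch (less padding waste).
        window = list(itertools.islice(examples, buffer))
        if not window:
            return
        window.sort(key=len)
        batch = []
        longest = 0
        for eg in window:
            longest_if_added = max(longest, len(eg))
            if longest_if_added * (len(batch) + 1) <= size:
                batch.append(eg)
                longest = longest_if_added
            else:
                if batch:
                    yield batch
                if len(eg) > size and discard_oversize:
                    # The example can never fit within the budget; drop it.
                    batch, longest = [], 0
                else:
                    batch, longest = [eg], len(eg)
        if batch:
            yield batch

Sorting only within a fixed-size window keeps batches dense without sorting the whole stream, which is presumably the role of the buffer=256 argument passed in the diff.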