mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Offer option of padding-sensitive batching
This commit is contained in:
parent
3a7f275c02
commit
77af0a6bb4
|
@ -303,11 +303,19 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
)
|
)
|
||||||
|
|
||||||
epoch = 0
|
epoch = 0
|
||||||
|
batch_strategy = cfg.get("batch_by", "sequences")
|
||||||
while True:
|
while True:
|
||||||
if len(train_examples) == 0:
|
if len(train_examples) == 0:
|
||||||
raise ValueError(Errors.E988)
|
raise ValueError(Errors.E988)
|
||||||
epoch += 1
|
epoch += 1
|
||||||
if cfg.get("batch_by_words", True):
|
if batch_strategy == "padded":
|
||||||
|
batches = util.minibatch_by_padded_size(
|
||||||
|
train_examples,
|
||||||
|
size=cfg["batch_size"],
|
||||||
|
buffer=256,
|
||||||
|
discard_oversize=cfg["discard_oversize"],
|
||||||
|
)
|
||||||
|
elif batch_strategy == "words":
|
||||||
batches = util.minibatch_by_words(
|
batches = util.minibatch_by_words(
|
||||||
train_examples,
|
train_examples,
|
||||||
size=cfg["batch_size"],
|
size=cfg["batch_size"],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user