mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Offer option of padding-sensitive batching
This commit is contained in:
parent
3a7f275c02
commit
77af0a6bb4
|
@ -303,11 +303,19 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
)
|
||||
|
||||
epoch = 0
|
||||
batch_strategy = cfg.get("batch_by", "sequences")
|
||||
while True:
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
epoch += 1
|
||||
if cfg.get("batch_by_words", True):
|
||||
if batch_strategy == "padded":
|
||||
batches = util.minibatch_by_padded_size(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
buffer=256,
|
||||
discard_oversize=cfg["discard_oversize"],
|
||||
)
|
||||
elif batch_strategy == "words":
|
||||
batches = util.minibatch_by_words(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
|
@ -318,7 +326,7 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
)
|
||||
|
||||
|
||||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||
try:
|
||||
first = next(batches)
|
||||
|
|
Loading…
Reference in New Issue
Block a user