Fix minibatch function for fixed (integer) batch sizes

This commit is contained in:
Matthew Honnibal 2017-09-14 13:37:41 +02:00
parent 456bb8a74c
commit ba23d63c35

View File

@ -7,6 +7,7 @@ import re
import ujson import ujson
import random import random
import cytoolz import cytoolz
import itertools
from .syntax import nonproj from .syntax import nonproj
from .util import ensure_path from .util import ensure_path
@ -146,9 +147,13 @@ def minibatch(items, size=8):
'''Iterate over batches of items. `size` may be an iterator, '''Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step. so that batch-size can vary on each step.
''' '''
if isinstance(size, int):
size_ = itertools.repeat(8)
else:
size_ = size
items = iter(items) items = iter(items)
while True: while True:
batch_size = next(size) #if hasattr(size, '__next__') else size batch_size = next(size_)
batch = list(cytoolz.take(int(batch_size), items)) batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0: if len(batch) == 0:
break break