Fix minibatching

This commit is contained in:
Matthew Honnibal 2017-07-22 20:14:49 +02:00
parent ded0df5e2f
commit 9bae0ddc50

View File

@ -148,7 +148,7 @@ def minibatch(items, size=8):
''' '''
items = iter(items) items = iter(items)
while True: while True:
batch_size = next(size) if hasattr(size, '__next__') else size batch_size = next(size) #if hasattr(size, '__next__') else size
batch = list(cytoolz.take(int(batch_size), items)) batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0: if len(batch) == 0:
break break