From ba23d63c35bf9187f093804f93af4fd345cfa1e3 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 14 Sep 2017 13:37:41 +0200
Subject: [PATCH] Fix minibatch function, for fixed batch size

minibatch() called next() on `size` unconditionally, which fails when
`size` is a plain integer. Wrap an integer `size` in itertools.repeat()
so that a fixed batch size and an iterator of batch sizes are both
accepted. The repeated value is the caller's `size`, not a hard-coded
constant.

---
 spacy/gold.pyx | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index f00d04109..fc8d6622b 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -7,6 +7,7 @@ import re
 import ujson
 import random
 import cytoolz
+import itertools
 
 from .syntax import nonproj
 from .util import ensure_path
@@ -146,9 +147,13 @@ def minibatch(items, size=8):
     '''Iterate over batches of items. `size` may be an iterator,
     so that batch-size can vary on each step.
     '''
+    if isinstance(size, int):
+        size_ = itertools.repeat(size)
+    else:
+        size_ = size
     items = iter(items)
     while True:
-        batch_size = next(size) #if hasattr(size, '__next__') else size
+        batch_size = next(size_)
         batch = list(cytoolz.take(int(batch_size), items))
         if len(batch) == 0:
             break