Move minibatch function to util

Matthew Honnibal 2017-11-06 23:45:36 +01:00
parent d7016d4050
commit 1cab703bba
2 changed files with 47 additions and 46 deletions

spacy/gold.pyx

@@ -11,6 +11,7 @@ import itertools
 from .syntax import nonproj
 from .tokens import Doc
 from . import util
+from .util import minibatch

 def tags_to_entities(tags):
@@ -144,23 +145,6 @@ def _min_edit_path(cand_words, gold_words):
     return prev_costs[n_gold], previous_row[-1]
-def minibatch(items, size=8):
-    """Iterate over batches of items. `size` may be an iterator,
-    so that batch-size can vary on each step.
-    """
-    if isinstance(size, int):
-        size_ = itertools.repeat(size)
-    else:
-        size_ = size
-    items = iter(items)
-    while True:
-        batch_size = next(size_)
-        batch = list(cytoolz.take(int(batch_size), items))
-        if len(batch) == 0:
-            break
-        yield list(batch)
 class GoldCorpus(object):
     """An annotated corpus, using the JSON file format. Manages
     annotations for tagging, dependency parsing and NER."""
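
The relocated `minibatch` pairs naturally with the `compounding` schedule that already lives in `util.py`: passing an iterator as `size` lets the batch size vary on each step. A minimal usage sketch (assuming the post-move `spacy.util` import path):

    from spacy.util import minibatch, compounding

    # Fixed batch size: yields [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]
    for batch in minibatch(range(10), size=4):
        print(batch)

    # Variable batch size: each step draws the next value from the
    # compounding schedule (4, 6, 9, 13, ... capped at 32)
    for batch in minibatch(range(100), size=compounding(4.0, 32.0, 1.5)):
        print(len(batch))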

spacy/util.py

@@ -251,35 +251,6 @@ def get_async(stream, numpy_array):
     return array
-def itershuffle(iterable, bufsize=1000):
-    """Shuffle an iterator. This works by holding `bufsize` items back
-    and yielding them sometime later. Obviously, this is not unbiased
-    but should be good enough for batching. Larger bufsize means less bias.
-    From https://gist.github.com/andres-erbsen/1307752
-
-    iterable (iterable): Iterator to shuffle.
-    bufsize (int): Items to hold back.
-    YIELDS (iterable): The shuffled iterator.
-    """
-    iterable = iter(iterable)
-    buf = []
-    try:
-        while True:
-            for i in range(random.randint(1, bufsize - len(buf))):
-                buf.append(next(iterable))
-            random.shuffle(buf)
-            for i in range(random.randint(1, bufsize)):
-                if buf:
-                    yield buf.pop()
-                else:
-                    break
-    except StopIteration:
-        random.shuffle(buf)
-        while buf:
-            yield buf.pop()
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -416,6 +387,23 @@ def normalize_slice(length, start, stop, step=None):
     return start, stop
+def minibatch(items, size=8):
+    """Iterate over batches of items. `size` may be an iterator,
+    so that batch-size can vary on each step.
+    """
+    if isinstance(size, int):
+        size_ = itertools.repeat(size)
+    else:
+        size_ = size
+    items = iter(items)
+    while True:
+        batch_size = next(size_)
+        batch = list(cytoolz.take(int(batch_size), items))
+        if len(batch) == 0:
+            break
+        yield list(batch)
 def compounding(start, stop, compound):
     """Yield an infinite series of compounding values. Each time the
     generator is called, a value is produced by multiplying the previous
@@ -445,6 +433,35 @@ def decaying(start, stop, decay):
         nr_upd += 1
+def itershuffle(iterable, bufsize=1000):
+    """Shuffle an iterator. This works by holding `bufsize` items back
+    and yielding them sometime later. Obviously, this is not unbiased
+    but should be good enough for batching. Larger bufsize means less bias.
+    From https://gist.github.com/andres-erbsen/1307752
+
+    iterable (iterable): Iterator to shuffle.
+    bufsize (int): Items to hold back.
+    YIELDS (iterable): The shuffled iterator.
+    """
+    iterable = iter(iterable)
+    buf = []
+    try:
+        while True:
+            for i in range(random.randint(1, bufsize - len(buf))):
+                buf.append(next(iterable))
+            random.shuffle(buf)
+            for i in range(random.randint(1, bufsize)):
+                if buf:
+                    yield buf.pop()
+                else:
+                    break
+    except StopIteration:
+        random.shuffle(buf)
+        while buf:
+            yield buf.pop()
 def read_json(location):
     """Open and load JSON from file.