Move minibatch function to util

This commit is contained in:
Matthew Honnibal 2017-11-06 23:45:36 +01:00
parent d7016d4050
commit 1cab703bba
2 changed files with 47 additions and 46 deletions

View File

@ -11,6 +11,7 @@ import itertools
from .syntax import nonproj
from .tokens import Doc
from . import util
from .util import minibatch
def tags_to_entities(tags):
@ -144,23 +145,6 @@ def _min_edit_path(cand_words, gold_words):
return prev_costs[n_gold], previous_row[-1]
def minibatch(items, size=8):
"""Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step.
"""
if isinstance(size, int):
size_ = itertools.repeat(8)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0:
break
yield list(batch)
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""

View File

@ -251,35 +251,6 @@ def get_async(stream, numpy_array):
return array
def itershuffle(iterable, bufsize=1000):
"""Shuffle an iterator. This works by holding `bufsize` items back
and yielding them sometime later. Obviously, this is not unbiased
but should be good enough for batching. Larger bufsize means less bias.
From https://gist.github.com/andres-erbsen/1307752
iterable (iterable): Iterator to shuffle.
bufsize (int): Items to hold back.
YIELDS (iterable): The shuffled iterator.
"""
iterable = iter(iterable)
buf = []
try:
while True:
for i in range(random.randint(1, bufsize-len(buf))):
buf.append(iterable.next())
random.shuffle(buf)
for i in range(random.randint(1, bufsize)):
if buf:
yield buf.pop()
else:
break
except StopIteration:
random.shuffle(buf)
while buf:
yield buf.pop()
raise StopIteration
def env_opt(name, default=None):
if type(default) is float:
type_convert = float
@ -416,6 +387,23 @@ def normalize_slice(length, start, stop, step=None):
return start, stop
def minibatch(items, size=8):
"""Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step.
"""
if isinstance(size, int):
size_ = itertools.repeat(8)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0:
break
yield list(batch)
def compounding(start, stop, compound):
"""Yield an infinite series of compounding values. Each time the
generator is called, a value is produced by multiplying the previous
@ -445,6 +433,35 @@ def decaying(start, stop, decay):
nr_upd += 1
def itershuffle(iterable, bufsize=1000):
"""Shuffle an iterator. This works by holding `bufsize` items back
and yielding them sometime later. Obviously, this is not unbiased
but should be good enough for batching. Larger bufsize means less bias.
From https://gist.github.com/andres-erbsen/1307752
iterable (iterable): Iterator to shuffle.
bufsize (int): Items to hold back.
YIELDS (iterable): The shuffled iterator.
"""
iterable = iter(iterable)
buf = []
try:
while True:
for i in range(random.randint(1, bufsize-len(buf))):
buf.append(iterable.next())
random.shuffle(buf)
for i in range(random.randint(1, bufsize)):
if buf:
yield buf.pop()
else:
break
except StopIteration:
random.shuffle(buf)
while buf:
yield buf.pop()
raise StopIteration
def read_json(location):
"""Open and load JSON from file.