Add itershuffle utility function. Maybe belongs in thinc

This commit is contained in:
Matthew Honnibal 2017-05-21 09:05:05 -05:00
parent 3b7c108246
commit 0731971bfc

View File

@ -9,6 +9,7 @@ import regex as re
from pathlib import Path from pathlib import Path
import sys import sys
import textwrap import textwrap
import random
from .symbols import ORTH from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
@ -172,6 +173,31 @@ def get_async(stream, numpy_array):
array.set(numpy_array, stream=stream) array.set(numpy_array, stream=stream)
return array return array
def itershuffle(iterable, bufsize=1000):
"""Shuffle an iterator. This works by holding `bufsize` items back
and yielding them sometime later. Obviously, this is not unbiased --
but should be good enough for batching. Larger bufsize means less bias.
From https://gist.github.com/andres-erbsen/1307752
"""
iterable = iter(iterable)
buf = []
try:
while True:
for i in range(random.randint(1, bufsize-len(buf))):
buf.append(iterable.next())
random.shuffle(buf)
for i in range(random.randint(1, bufsize)):
if buf:
yield buf.pop()
else:
break
except StopIteration:
random.shuffle(buf)
while buf:
yield buf.pop()
raise StopIteration
def env_opt(name, default=None): def env_opt(name, default=None):
if type(default) is float: if type(default) is float: