mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Move minibatch function to util
This commit is contained in:
parent
d7016d4050
commit
1cab703bba
|
@ -11,6 +11,7 @@ import itertools
|
||||||
from .syntax import nonproj
|
from .syntax import nonproj
|
||||||
from .tokens import Doc
|
from .tokens import Doc
|
||||||
from . import util
|
from . import util
|
||||||
|
from .util import minibatch
|
||||||
|
|
||||||
|
|
||||||
def tags_to_entities(tags):
|
def tags_to_entities(tags):
|
||||||
|
@ -144,23 +145,6 @@ def _min_edit_path(cand_words, gold_words):
|
||||||
return prev_costs[n_gold], previous_row[-1]
|
return prev_costs[n_gold], previous_row[-1]
|
||||||
|
|
||||||
|
|
||||||
def minibatch(items, size=8):
|
|
||||||
"""Iterate over batches of items. `size` may be an iterator,
|
|
||||||
so that batch-size can vary on each step.
|
|
||||||
"""
|
|
||||||
if isinstance(size, int):
|
|
||||||
size_ = itertools.repeat(8)
|
|
||||||
else:
|
|
||||||
size_ = size
|
|
||||||
items = iter(items)
|
|
||||||
while True:
|
|
||||||
batch_size = next(size_)
|
|
||||||
batch = list(cytoolz.take(int(batch_size), items))
|
|
||||||
if len(batch) == 0:
|
|
||||||
break
|
|
||||||
yield list(batch)
|
|
||||||
|
|
||||||
|
|
||||||
class GoldCorpus(object):
|
class GoldCorpus(object):
|
||||||
"""An annotated corpus, using the JSON file format. Manages
|
"""An annotated corpus, using the JSON file format. Manages
|
||||||
annotations for tagging, dependency parsing and NER."""
|
annotations for tagging, dependency parsing and NER."""
|
||||||
|
|
|
@ -251,35 +251,6 @@ def get_async(stream, numpy_array):
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def itershuffle(iterable, bufsize=1000):
|
|
||||||
"""Shuffle an iterator. This works by holding `bufsize` items back
|
|
||||||
and yielding them sometime later. Obviously, this is not unbiased –
|
|
||||||
but should be good enough for batching. Larger bufsize means less bias.
|
|
||||||
From https://gist.github.com/andres-erbsen/1307752
|
|
||||||
|
|
||||||
iterable (iterable): Iterator to shuffle.
|
|
||||||
bufsize (int): Items to hold back.
|
|
||||||
YIELDS (iterable): The shuffled iterator.
|
|
||||||
"""
|
|
||||||
iterable = iter(iterable)
|
|
||||||
buf = []
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
for i in range(random.randint(1, bufsize-len(buf))):
|
|
||||||
buf.append(iterable.next())
|
|
||||||
random.shuffle(buf)
|
|
||||||
for i in range(random.randint(1, bufsize)):
|
|
||||||
if buf:
|
|
||||||
yield buf.pop()
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
except StopIteration:
|
|
||||||
random.shuffle(buf)
|
|
||||||
while buf:
|
|
||||||
yield buf.pop()
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
|
|
||||||
def env_opt(name, default=None):
|
def env_opt(name, default=None):
|
||||||
if type(default) is float:
|
if type(default) is float:
|
||||||
type_convert = float
|
type_convert = float
|
||||||
|
@ -416,6 +387,23 @@ def normalize_slice(length, start, stop, step=None):
|
||||||
return start, stop
|
return start, stop
|
||||||
|
|
||||||
|
|
||||||
|
def minibatch(items, size=8):
|
||||||
|
"""Iterate over batches of items. `size` may be an iterator,
|
||||||
|
so that batch-size can vary on each step.
|
||||||
|
"""
|
||||||
|
if isinstance(size, int):
|
||||||
|
size_ = itertools.repeat(8)
|
||||||
|
else:
|
||||||
|
size_ = size
|
||||||
|
items = iter(items)
|
||||||
|
while True:
|
||||||
|
batch_size = next(size_)
|
||||||
|
batch = list(cytoolz.take(int(batch_size), items))
|
||||||
|
if len(batch) == 0:
|
||||||
|
break
|
||||||
|
yield list(batch)
|
||||||
|
|
||||||
|
|
||||||
def compounding(start, stop, compound):
|
def compounding(start, stop, compound):
|
||||||
"""Yield an infinite series of compounding values. Each time the
|
"""Yield an infinite series of compounding values. Each time the
|
||||||
generator is called, a value is produced by multiplying the previous
|
generator is called, a value is produced by multiplying the previous
|
||||||
|
@ -445,6 +433,35 @@ def decaying(start, stop, decay):
|
||||||
nr_upd += 1
|
nr_upd += 1
|
||||||
|
|
||||||
|
|
||||||
|
def itershuffle(iterable, bufsize=1000):
|
||||||
|
"""Shuffle an iterator. This works by holding `bufsize` items back
|
||||||
|
and yielding them sometime later. Obviously, this is not unbiased –
|
||||||
|
but should be good enough for batching. Larger bufsize means less bias.
|
||||||
|
From https://gist.github.com/andres-erbsen/1307752
|
||||||
|
|
||||||
|
iterable (iterable): Iterator to shuffle.
|
||||||
|
bufsize (int): Items to hold back.
|
||||||
|
YIELDS (iterable): The shuffled iterator.
|
||||||
|
"""
|
||||||
|
iterable = iter(iterable)
|
||||||
|
buf = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
for i in range(random.randint(1, bufsize-len(buf))):
|
||||||
|
buf.append(iterable.next())
|
||||||
|
random.shuffle(buf)
|
||||||
|
for i in range(random.randint(1, bufsize)):
|
||||||
|
if buf:
|
||||||
|
yield buf.pop()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except StopIteration:
|
||||||
|
random.shuffle(buf)
|
||||||
|
while buf:
|
||||||
|
yield buf.pop()
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
|
||||||
def read_json(location):
|
def read_json(location):
|
||||||
"""Open and load JSON from file.
|
"""Open and load JSON from file.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user