Add minibatch function in spacy.gold

This commit is contained in:
Matthew Honnibal 2017-05-25 17:15:09 -05:00
parent 702fe74a4d
commit 3a6e59cc53

View File

@ -6,6 +6,7 @@ import io
import re import re
import ujson import ujson
import random import random
import cytoolz
from .syntax import nonproj from .syntax import nonproj
from .util import ensure_path from .util import ensure_path
@ -141,6 +142,19 @@ def _min_edit_path(cand_words, gold_words):
return prev_costs[n_gold], previous_row[-1] return prev_costs[n_gold], previous_row[-1]
def minibatch(items, size=8):
    """Iterate over batches of items.

    items: Any iterable. It is consumed lazily, one batch at a time.
    size: Either a single number, used as the size of every batch, or an
        iterable/iterator of numbers, so that the batch size can vary on
        each step. Each value is truncated with `int()` before use.

    Yields lists of at most the requested number of items. Iteration stops
    when `items` is exhausted, when a requested batch size truncates to 0,
    or when `size` is an iterable that runs out of values.
    """
    # Imported locally so this function only depends on the standard library.
    from itertools import islice, repeat

    try:
        # Iterators/generators return themselves here, so a schedule that
        # the caller is also holding a reference to advances as expected.
        sizes = iter(size)
    except TypeError:
        # A plain number (such as the default) is not iterable: use the
        # same size for every batch.
        sizes = repeat(size)
    items = iter(items)
    # Driving the loop from `sizes` lets an exhausted size schedule end the
    # generator cleanly, instead of leaking StopIteration out of `next()`
    # (which is a RuntimeError inside generators on Python 3.7+).
    for batch_size in sizes:
        batch = list(islice(items, int(batch_size)))
        if not batch:
            break
        yield batch
class GoldCorpus(object): class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages """An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER.""" annotations for tagging, dependency parsing and NER."""
@ -396,7 +410,10 @@ cdef class GoldParse:
else: else:
self.words[i] = words[gold_i] self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i] self.tags[i] = tags[gold_i]
self.heads[i] = self.gold_to_cand[heads[gold_i]] if heads[gold_i] is None:
self.heads[i] = None
else:
self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i] self.labels[i] = deps[gold_i]
self.ner[i] = entities[gold_i] self.ner[i] = entities[gold_i]