Add minibatch function in spacy.gold

This commit is contained in:
Matthew Honnibal 2017-05-25 17:15:09 -05:00
parent 702fe74a4d
commit 3a6e59cc53

View File

@ -6,6 +6,7 @@ import io
import re
import ujson
import random
import cytoolz
from .syntax import nonproj
from .util import ensure_path
@ -141,6 +142,19 @@ def _min_edit_path(cand_words, gold_words):
return prev_costs[n_gold], previous_row[-1]
def minibatch(items, size=8):
'''Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step.
'''
items = iter(items)
while True:
batch_size = next(size) #if hasattr(size, '__next__') else size
batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0:
break
yield list(batch)
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""
@ -396,7 +410,10 @@ cdef class GoldParse:
else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i]
self.heads[i] = self.gold_to_cand[heads[gold_i]]
if heads[gold_i] is None:
self.heads[i] = None
else:
self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i]
self.ner[i] = entities[gold_i]