# Provenance: spacy/gold.pyx, commit 3a6e59cc53fd49293336ced657050022aedb1df5
# ("Add minibatch function in spacy.gold", Matthew Honnibal, 2017-05-25).
#
# NOTE: the same patch also changes GoldParse so that, when the gold head for
# a token is None, self.heads[i] is set to None instead of being looked up in
# self.gold_to_cand.


def minibatch(items, size=8):
    """Iterate over batches of items.

    items: Any iterable; it is consumed lazily, one batch at a time.
    size: Either a single number, used as the size of every batch, or an
        iterable/iterator of numbers, so that the batch size can vary on
        each step. Each value is truncated with `int()` before use.

    Yields lists of at most `int(size)` consecutive items. Iteration stops
    when `items` is exhausted, when a batch comes back empty (e.g. a size
    that truncates to 0), or when a `size` iterable runs out of values.
    """
    # Local import so the block does not rely on module-level imports.
    import itertools

    try:
        sizes = iter(size)
    except TypeError:
        # A plain number (such as the default, 8) is not iterable: reuse it
        # for every batch. Calling next() on it directly would raise.
        sizes = itertools.repeat(size)
    items = iter(items)
    while True:
        try:
            batch_size = next(sizes)
        except StopIteration:
            # End cleanly when the size schedule is exhausted. Letting
            # StopIteration escape a generator body is a RuntimeError under
            # PEP 479 (Python 3.7+).
            return
        batch = list(itertools.islice(items, int(batch_size)))
        if not batch:
            return
        yield batch