From 9b1b0742fd1cd621500052b4ed1a214c429621ae Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 May 2017 17:52:01 -0500 Subject: [PATCH] Fix prediction for tok2vec --- spacy/pipeline.pyx | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index b0b440727..91217b80b 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -93,6 +93,7 @@ class TokenVectorEncoder(object): YIELDS (iterator): A sequence of `Doc` objects, in order of input. """ for docs in cytoolz.partition_all(batch_size, stream): + docs = list(docs) tokvecses = self.predict(docs) self.set_annotations(docs, tokvecses) yield from docs @@ -108,19 +109,14 @@ class TokenVectorEncoder(object): return tokvecs def set_annotations(self, docs, tokvecses): - for doc, tokvecs in zip(docs, tokvecses): - doc.tensor = tokvecs - - def set_annotations(self, docs, tokvecs): """Set the tensor attribute for a batch of documents. docs (iterable): A sequence of `Doc` objects. tokvecs (object): Vector representation for each token in the documents. """ - start = 0 - for doc in docs: - doc.tensor = tokvecs[start : start + len(doc)] - start += len(doc) + for doc, tokvecs in zip(docs, tokvecses): + assert tokvecs.shape[0] == len(doc) + doc.tensor = tokvecs def update(self, docs, golds, state=None, drop=0., sgd=None): """Update the model.