From 6681058abd08a52964960fb41c7e201a49277bcb Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 3 Nov 2017 13:29:36 +0100 Subject: [PATCH] Fix tensor extending in tagger --- spacy/pipeline.pyx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 283e4b106..e55710dee 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -352,10 +352,12 @@ class Tagger(Pipe): def predict(self, docs): tokvecs = self.model.tok2vec(docs) scores = self.model.softmax(tokvecs) - guesses = scores.argmax(axis=1) - if not isinstance(guesses, numpy.ndarray): - guesses = guesses.get() - guesses = self.model.ops.unflatten(guesses, [len(d) for d in docs]) + guesses = [] + for doc_scores in scores: + doc_guesses = doc_scores.argmax(axis=1) + if not isinstance(doc_guesses, numpy.ndarray): + doc_guesses = doc_guesses.get() + guesses.append(doc_guesses) return guesses, tokvecs def set_annotations(self, docs, batch_tag_ids, tensors=None):