Fix tensor extending in tagger

This commit is contained in:
Matthew Honnibal 2017-11-03 13:29:36 +01:00
parent bd2cbdfa85
commit 6681058abd

View File

@ -352,10 +352,12 @@ class Tagger(Pipe):
def predict(self, docs):
tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs)
guesses = scores.argmax(axis=1)
if not isinstance(guesses, numpy.ndarray):
guesses = guesses.get()
guesses = self.model.ops.unflatten(guesses, [len(d) for d in docs])
guesses = []
for doc_scores in scores:
doc_guesses = doc_scores.argmax(axis=1)
if not isinstance(doc_guesses, numpy.ndarray):
doc_guesses = doc_guesses.get()
guesses.append(doc_guesses)
return guesses, tokvecs
def set_annotations(self, docs, batch_tag_ids, tensors=None):