mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Fix tensor extending in tagger
This commit is contained in:
parent
bd2cbdfa85
commit
6681058abd
|
@ -352,10 +352,12 @@ class Tagger(Pipe):
|
|||
def predict(self, docs):
|
||||
tokvecs = self.model.tok2vec(docs)
|
||||
scores = self.model.softmax(tokvecs)
|
||||
guesses = scores.argmax(axis=1)
|
||||
if not isinstance(guesses, numpy.ndarray):
|
||||
guesses = guesses.get()
|
||||
guesses = self.model.ops.unflatten(guesses, [len(d) for d in docs])
|
||||
guesses = []
|
||||
for doc_scores in scores:
|
||||
doc_guesses = doc_scores.argmax(axis=1)
|
||||
if not isinstance(doc_guesses, numpy.ndarray):
|
||||
doc_guesses = doc_guesses.get()
|
||||
guesses.append(doc_guesses)
|
||||
return guesses, tokvecs
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids, tensors=None):
|
||||
|
|
Loading…
Reference in New Issue
Block a user