mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 08:14:15 +03:00
Fix tokvecs flattening in pipeline
This commit is contained in:
parent
0731971bfc
commit
180e5afede
|
@ -105,16 +105,19 @@ class NeuralTagger(object):
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
for docs in cytoolz.partition_all(batch_size, stream):
|
for docs in cytoolz.partition_all(batch_size, stream):
|
||||||
tokvecs = self.model.ops.flatten([d.tensor for d in docs])
|
tokvecs = [d.tensor for d in docs]
|
||||||
tag_ids = self.predict(tokvecs)
|
tag_ids = self.predict(tokvecs)
|
||||||
self.set_annotations(docs, tag_ids)
|
self.set_annotations(docs, tag_ids)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, tokvecs):
|
def predict(self, tokvecs):
|
||||||
scores = self.model(tokvecs)
|
scores = self.model(tokvecs)
|
||||||
|
scores = self.model.ops.flatten(scores)
|
||||||
guesses = scores.argmax(axis=1)
|
guesses = scores.argmax(axis=1)
|
||||||
if not isinstance(guesses, numpy.ndarray):
|
if not isinstance(guesses, numpy.ndarray):
|
||||||
guesses = guesses.get()
|
guesses = guesses.get()
|
||||||
|
guesses = self.model.ops.unflatten(guesses,
|
||||||
|
[tv.shape[0] for tv in tokvecs])
|
||||||
return guesses
|
return guesses
|
||||||
|
|
||||||
def set_annotations(self, docs, batch_tag_ids):
|
def set_annotations(self, docs, batch_tag_ids):
|
||||||
|
@ -122,10 +125,9 @@ class NeuralTagger(object):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
cdef int i, j, tag_id
|
|
||||||
cdef Vocab vocab = self.vocab
|
cdef Vocab vocab = self.vocab
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[idx:idx+len(doc)]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user