mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 03:26:24 +03:00
Fix tensorizer on GPU
This commit is contained in:
parent
6e5181bbaa
commit
2b35bb76ad
|
@ -415,7 +415,11 @@ class Tagger(Pipe):
|
||||||
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||||
idx += 1
|
idx += 1
|
||||||
if tensors is not None:
|
if tensors is not None:
|
||||||
doc.extend_tensor(tensors[i])
|
if isinstance(doc.tensor, numpy.ndarray) \
|
||||||
|
and not isinstance(tensors[i], numpy.ndarray):
|
||||||
|
doc.extend_tensor(tensors[i].get())
|
||||||
|
else:
|
||||||
|
doc.extend_tensor(tensors[i])
|
||||||
doc.is_tagged = True
|
doc.is_tagged = True
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||||
|
|
|
@ -751,7 +751,11 @@ cdef class Parser:
|
||||||
for j in range(doc.length):
|
for j in range(doc.length):
|
||||||
doc.c[j] = state.c._sent[j]
|
doc.c[j] = state.c._sent[j]
|
||||||
if tensors is not None:
|
if tensors is not None:
|
||||||
doc.extend_tensor(tensors[i])
|
if isinstance(doc.tensor, numpy.ndarray) \
|
||||||
|
and not isinstance(tensors[i], numpy.ndarray):
|
||||||
|
doc.extend_tensor(tensors[i].get())
|
||||||
|
else:
|
||||||
|
doc.extend_tensor(tensors[i])
|
||||||
self.moves.finalize_doc(doc)
|
self.moves.finalize_doc(doc)
|
||||||
|
|
||||||
for hook in self.postprocesses:
|
for hook in self.postprocesses:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user