From 2b35bb76addc664d722cff0d00a2cf597610c347 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 5 Nov 2017 15:34:40 +0100 Subject: [PATCH] Fix tensorizer on GPU --- spacy/pipeline.pyx | 6 +++++- spacy/syntax/nn_parser.pyx | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 5a72dc946..f3defeeb9 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -415,7 +415,11 @@ class Tagger(Pipe): vocab.morphology.assign_tag_id(&doc.c[j], tag_id) idx += 1 if tensors is not None: - doc.extend_tensor(tensors[i]) + if isinstance(doc.tensor, numpy.ndarray) \ + and not isinstance(tensors[i], numpy.ndarray): + doc.extend_tensor(tensors[i].get()) + else: + doc.extend_tensor(tensors[i]) doc.is_tagged = True def update(self, docs, golds, drop=0., sgd=None, losses=None): diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 6bfd729eb..08b01a88f 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -751,7 +751,11 @@ cdef class Parser: for j in range(doc.length): doc.c[j] = state.c._sent[j] if tensors is not None: - doc.extend_tensor(tensors[i]) + if isinstance(doc.tensor, numpy.ndarray) \ + and not isinstance(tensors[i], numpy.ndarray): + doc.extend_tensor(tensors[i].get()) + else: + doc.extend_tensor(tensors[i]) self.moves.finalize_doc(doc) for hook in self.postprocesses: