From 72889a16d558848191e51f4bfb200e70d3bc413a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 20 Mar 2019 12:09:59 +0100 Subject: [PATCH] Fix similarity calculation if vectors are on GPU (#3440) --- spacy/tokens/doc.pyx | 7 ++++--- spacy/tokens/span.pyx | 10 ++++------ spacy/tokens/token.pyx | 4 +++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index dd610bd6d..d4d7e5fa4 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -416,8 +416,9 @@ cdef class Doc: return self.user_hooks["vector"](self) if self._vector is not None: return self._vector - elif not len(self): - self._vector = numpy.zeros((self.vocab.vectors_length,), dtype="f") + xp = get_array_module(self.vocab.vectors.data) + if not len(self): + self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f") return self._vector elif self.vocab.vectors.data.size > 0: self._vector = sum(t.vector for t in self) / len(self) @@ -426,7 +427,7 @@ cdef class Doc: self._vector = self.tensor.mean(axis=0) return self._vector else: - return numpy.zeros((self.vocab.vectors_length,), dtype="float32") + return xp.zeros((self.vocab.vectors_length,), dtype="float32") def __set__(self, value): self._vector = value diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 36eaeb568..e62caed40 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -420,13 +420,11 @@ cdef class Span: """ if "vector_norm" in self.doc.user_span_hooks: return self.doc.user_span_hooks["vector"](self) - cdef float value - cdef double norm = 0 + vector = self.vector + xp = get_array_module(vector) if self._vector_norm is None: - norm = 0 - for value in self.vector: - norm += value * value - self._vector_norm = sqrt(norm) if norm != 0 else 0 + total = (vector*vector).sum() + self._vector_norm = xp.sqrt(total) if total != 0. else 0. return self._vector_norm @property diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index 409b68290..66728d35c 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -404,7 +404,9 @@ cdef class Token: if "vector_norm" in self.doc.user_token_hooks: return self.doc.user_token_hooks["vector_norm"](self) vector = self.vector - return numpy.sqrt((vector ** 2).sum()) + xp = get_array_module(vector) + total = (vector ** 2).sum() + return xp.sqrt(total) if total != 0. else 0. @property def n_lefts(self):