Fix similarity calculation if vectors are on GPU (#3440)

This commit is contained in:
Matthew Honnibal 2019-03-20 12:09:59 +01:00 committed by Ines Montani
parent 1612990e88
commit 72889a16d5
3 changed files with 11 additions and 10 deletions

View File

@ -416,8 +416,9 @@ cdef class Doc:
return self.user_hooks["vector"](self) return self.user_hooks["vector"](self)
if self._vector is not None: if self._vector is not None:
return self._vector return self._vector
elif not len(self): xp = get_array_module(self.vocab.vectors.data)
self._vector = numpy.zeros((self.vocab.vectors_length,), dtype="f") if not len(self):
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
return self._vector return self._vector
elif self.vocab.vectors.data.size > 0: elif self.vocab.vectors.data.size > 0:
self._vector = sum(t.vector for t in self) / len(self) self._vector = sum(t.vector for t in self) / len(self)
@ -426,7 +427,7 @@ cdef class Doc:
self._vector = self.tensor.mean(axis=0) self._vector = self.tensor.mean(axis=0)
return self._vector return self._vector
else: else:
return numpy.zeros((self.vocab.vectors_length,), dtype="float32") return xp.zeros((self.vocab.vectors_length,), dtype="float32")
def __set__(self, value): def __set__(self, value):
self._vector = value self._vector = value

View File

@ -420,13 +420,11 @@ cdef class Span:
""" """
if "vector_norm" in self.doc.user_span_hooks: if "vector_norm" in self.doc.user_span_hooks:
return self.doc.user_span_hooks["vector"](self) return self.doc.user_span_hooks["vector"](self)
cdef float value vector = self.vector
cdef double norm = 0 xp = get_array_module(vector)
if self._vector_norm is None: if self._vector_norm is None:
norm = 0 total = (vector*vector).sum()
for value in self.vector: self._vector_norm = xp.sqrt(total) if total != 0. else 0.
norm += value * value
self._vector_norm = sqrt(norm) if norm != 0 else 0
return self._vector_norm return self._vector_norm
@property @property

View File

@ -404,7 +404,9 @@ cdef class Token:
if "vector_norm" in self.doc.user_token_hooks: if "vector_norm" in self.doc.user_token_hooks:
return self.doc.user_token_hooks["vector_norm"](self) return self.doc.user_token_hooks["vector_norm"](self)
vector = self.vector vector = self.vector
return numpy.sqrt((vector ** 2).sum()) xp = get_array_module(vector)
total = (vector ** 2).sum()
return xp.sqrt(total) if total != 0. else 0.
@property @property
def n_lefts(self): def n_lefts(self):