mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix similarity calculation if vectors are on GPU (#3440)
This commit is contained in:
parent
1612990e88
commit
72889a16d5
|
@ -416,8 +416,9 @@ cdef class Doc:
|
|||
return self.user_hooks["vector"](self)
|
||||
if self._vector is not None:
|
||||
return self._vector
|
||||
elif not len(self):
|
||||
self._vector = numpy.zeros((self.vocab.vectors_length,), dtype="f")
|
||||
xp = get_array_module(self.vocab.vectors.data)
|
||||
if not len(self):
|
||||
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
|
||||
return self._vector
|
||||
elif self.vocab.vectors.data.size > 0:
|
||||
self._vector = sum(t.vector for t in self) / len(self)
|
||||
|
@ -426,7 +427,7 @@ cdef class Doc:
|
|||
self._vector = self.tensor.mean(axis=0)
|
||||
return self._vector
|
||||
else:
|
||||
return numpy.zeros((self.vocab.vectors_length,), dtype="float32")
|
||||
return xp.zeros((self.vocab.vectors_length,), dtype="float32")
|
||||
|
||||
def __set__(self, value):
|
||||
self._vector = value
|
||||
|
|
|
@ -420,13 +420,11 @@ cdef class Span:
|
|||
"""
|
||||
if "vector_norm" in self.doc.user_span_hooks:
|
||||
return self.doc.user_span_hooks["vector"](self)
|
||||
cdef float value
|
||||
cdef double norm = 0
|
||||
vector = self.vector
|
||||
xp = get_array_module(vector)
|
||||
if self._vector_norm is None:
|
||||
norm = 0
|
||||
for value in self.vector:
|
||||
norm += value * value
|
||||
self._vector_norm = sqrt(norm) if norm != 0 else 0
|
||||
total = (vector*vector).sum()
|
||||
self._vector_norm = xp.sqrt(total) if total != 0. else 0.
|
||||
return self._vector_norm
|
||||
|
||||
@property
|
||||
|
|
|
@ -404,7 +404,9 @@ cdef class Token:
|
|||
if "vector_norm" in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks["vector_norm"](self)
|
||||
vector = self.vector
|
||||
return numpy.sqrt((vector ** 2).sum())
|
||||
xp = get_array_module(vector)
|
||||
total = (vector ** 2).sum()
|
||||
return xp.sqrt(total) if total != 0. else 0.
|
||||
|
||||
@property
|
||||
def n_lefts(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user