mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Fix similarity calculation if vectors are on GPU (#3440)
This commit is contained in:
parent
1612990e88
commit
72889a16d5
|
@ -416,8 +416,9 @@ cdef class Doc:
|
||||||
return self.user_hooks["vector"](self)
|
return self.user_hooks["vector"](self)
|
||||||
if self._vector is not None:
|
if self._vector is not None:
|
||||||
return self._vector
|
return self._vector
|
||||||
elif not len(self):
|
xp = get_array_module(self.vocab.vectors.data)
|
||||||
self._vector = numpy.zeros((self.vocab.vectors_length,), dtype="f")
|
if not len(self):
|
||||||
|
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
|
||||||
return self._vector
|
return self._vector
|
||||||
elif self.vocab.vectors.data.size > 0:
|
elif self.vocab.vectors.data.size > 0:
|
||||||
self._vector = sum(t.vector for t in self) / len(self)
|
self._vector = sum(t.vector for t in self) / len(self)
|
||||||
|
@ -426,7 +427,7 @@ cdef class Doc:
|
||||||
self._vector = self.tensor.mean(axis=0)
|
self._vector = self.tensor.mean(axis=0)
|
||||||
return self._vector
|
return self._vector
|
||||||
else:
|
else:
|
||||||
return numpy.zeros((self.vocab.vectors_length,), dtype="float32")
|
return xp.zeros((self.vocab.vectors_length,), dtype="float32")
|
||||||
|
|
||||||
def __set__(self, value):
|
def __set__(self, value):
|
||||||
self._vector = value
|
self._vector = value
|
||||||
|
|
|
@ -420,13 +420,11 @@ cdef class Span:
|
||||||
"""
|
"""
|
||||||
if "vector_norm" in self.doc.user_span_hooks:
|
if "vector_norm" in self.doc.user_span_hooks:
|
||||||
return self.doc.user_span_hooks["vector"](self)
|
return self.doc.user_span_hooks["vector"](self)
|
||||||
cdef float value
|
vector = self.vector
|
||||||
cdef double norm = 0
|
xp = get_array_module(vector)
|
||||||
if self._vector_norm is None:
|
if self._vector_norm is None:
|
||||||
norm = 0
|
total = (vector*vector).sum()
|
||||||
for value in self.vector:
|
self._vector_norm = xp.sqrt(total) if total != 0. else 0.
|
||||||
norm += value * value
|
|
||||||
self._vector_norm = sqrt(norm) if norm != 0 else 0
|
|
||||||
return self._vector_norm
|
return self._vector_norm
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -404,7 +404,9 @@ cdef class Token:
|
||||||
if "vector_norm" in self.doc.user_token_hooks:
|
if "vector_norm" in self.doc.user_token_hooks:
|
||||||
return self.doc.user_token_hooks["vector_norm"](self)
|
return self.doc.user_token_hooks["vector_norm"](self)
|
||||||
vector = self.vector
|
vector = self.vector
|
||||||
return numpy.sqrt((vector ** 2).sum())
|
xp = get_array_module(vector)
|
||||||
|
total = (vector ** 2).sum()
|
||||||
|
return xp.sqrt(total) if total != 0. else 0.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_lefts(self):
|
def n_lefts(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user