mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Back-off to tensor for similarity if no vectors
This commit is contained in:
parent
1e9634691a
commit
144a93c2a5
|
@ -75,3 +75,11 @@ def test_en_models_probs(example):
|
||||||
assert not prob0 == prob1
|
assert not prob0 == prob1
|
||||||
assert not prob0 == prob2
|
assert not prob0 == prob2
|
||||||
assert not prob1 == prob2
|
assert not prob1 == prob2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.models('en')
|
||||||
|
def test_no_vectors_similarity(EN):
|
||||||
|
doc1 = EN(u'hallo')
|
||||||
|
doc2 = EN(u'hi')
|
||||||
|
assert doc1.similarity(doc2) > 0
|
||||||
|
|
||||||
|
|
|
@ -307,7 +307,7 @@ cdef class Doc:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'has_vector' in self.user_hooks:
|
if 'has_vector' in self.user_hooks:
|
||||||
return self.user_hooks['has_vector'](self)
|
return self.user_hooks['has_vector'](self)
|
||||||
elif any(token.has_vector for token in self):
|
elif self.vocab.vectors.data.size:
|
||||||
return True
|
return True
|
||||||
elif self.tensor.size:
|
elif self.tensor.size:
|
||||||
return True
|
return True
|
||||||
|
@ -330,13 +330,13 @@ cdef class Doc:
|
||||||
self._vector = numpy.zeros((self.vocab.vectors_length,),
|
self._vector = numpy.zeros((self.vocab.vectors_length,),
|
||||||
dtype='f')
|
dtype='f')
|
||||||
return self._vector
|
return self._vector
|
||||||
elif self.has_vector:
|
elif self.vocab.vectors.data.size > 0:
|
||||||
vector = numpy.zeros((self.vocab.vectors_length,), dtype='f')
|
vector = numpy.zeros((self.vocab.vectors_length,), dtype='f')
|
||||||
for token in self.c[:self.length]:
|
for token in self.c[:self.length]:
|
||||||
vector += self.vocab.get_vector(token.lex.orth)
|
vector += self.vocab.get_vector(token.lex.orth)
|
||||||
self._vector = vector / len(self)
|
self._vector = vector / len(self)
|
||||||
return self._vector
|
return self._vector
|
||||||
elif self.tensor.size:
|
elif self.tensor.size > 0:
|
||||||
self._vector = self.tensor.mean(axis=0)
|
self._vector = self.tensor.mean(axis=0)
|
||||||
return self._vector
|
return self._vector
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -283,7 +283,12 @@ cdef class Span:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'has_vector' in self.doc.user_span_hooks:
|
if 'has_vector' in self.doc.user_span_hooks:
|
||||||
return self.doc.user_span_hooks['has_vector'](self)
|
return self.doc.user_span_hooks['has_vector'](self)
|
||||||
return any(token.has_vector for token in self)
|
elif self.vocab.vectors.data.size > 0:
|
||||||
|
return any(token.has_vector for token in self)
|
||||||
|
elif self.doc.tensor.size > 0:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
property vector:
|
property vector:
|
||||||
"""A real-valued meaning representation. Defaults to an average of the
|
"""A real-valued meaning representation. Defaults to an average of the
|
||||||
|
|
|
@ -292,6 +292,8 @@ cdef class Token:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'has_vector' in self.doc.user_token_hooks:
|
if 'has_vector' in self.doc.user_token_hooks:
|
||||||
return self.doc.user_token_hooks['has_vector'](self)
|
return self.doc.user_token_hooks['has_vector'](self)
|
||||||
|
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||||
|
return True
|
||||||
return self.vocab.has_vector(self.c.lex.orth)
|
return self.vocab.has_vector(self.c.lex.orth)
|
||||||
|
|
||||||
property vector:
|
property vector:
|
||||||
|
@ -303,7 +305,10 @@ cdef class Token:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'vector' in self.doc.user_token_hooks:
|
if 'vector' in self.doc.user_token_hooks:
|
||||||
return self.doc.user_token_hooks['vector'](self)
|
return self.doc.user_token_hooks['vector'](self)
|
||||||
return self.vocab.get_vector(self.c.lex.orth)
|
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||||
|
return self.doc.tensor[self.i]
|
||||||
|
else:
|
||||||
|
return self.vocab.get_vector(self.c.lex.orth)
|
||||||
|
|
||||||
property vector_norm:
|
property vector_norm:
|
||||||
"""The L2 norm of the token's vector representation.
|
"""The L2 norm of the token's vector representation.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user