mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Back-off to tensor for similarity if no vectors
This commit is contained in:
parent
1e9634691a
commit
144a93c2a5
|
@ -75,3 +75,11 @@ def test_en_models_probs(example):
|
|||
assert not prob0 == prob1
|
||||
assert not prob0 == prob2
|
||||
assert not prob1 == prob2
|
||||
|
||||
|
||||
@pytest.mark.models('en')
|
||||
def test_no_vectors_similarity(EN):
|
||||
doc1 = EN(u'hallo')
|
||||
doc2 = EN(u'hi')
|
||||
assert doc1.similarity(doc2) > 0
|
||||
|
||||
|
|
|
@ -307,7 +307,7 @@ cdef class Doc:
|
|||
def __get__(self):
|
||||
if 'has_vector' in self.user_hooks:
|
||||
return self.user_hooks['has_vector'](self)
|
||||
elif any(token.has_vector for token in self):
|
||||
elif self.vocab.vectors.data.size:
|
||||
return True
|
||||
elif self.tensor.size:
|
||||
return True
|
||||
|
@ -330,13 +330,13 @@ cdef class Doc:
|
|||
self._vector = numpy.zeros((self.vocab.vectors_length,),
|
||||
dtype='f')
|
||||
return self._vector
|
||||
elif self.has_vector:
|
||||
elif self.vocab.vectors.data.size > 0:
|
||||
vector = numpy.zeros((self.vocab.vectors_length,), dtype='f')
|
||||
for token in self.c[:self.length]:
|
||||
vector += self.vocab.get_vector(token.lex.orth)
|
||||
self._vector = vector / len(self)
|
||||
return self._vector
|
||||
elif self.tensor.size:
|
||||
elif self.tensor.size > 0:
|
||||
self._vector = self.tensor.mean(axis=0)
|
||||
return self._vector
|
||||
else:
|
||||
|
|
|
@ -283,7 +283,12 @@ cdef class Span:
|
|||
def __get__(self):
|
||||
if 'has_vector' in self.doc.user_span_hooks:
|
||||
return self.doc.user_span_hooks['has_vector'](self)
|
||||
return any(token.has_vector for token in self)
|
||||
elif self.vocab.vectors.data.size > 0:
|
||||
return any(token.has_vector for token in self)
|
||||
elif self.doc.tensor.size > 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
property vector:
|
||||
"""A real-valued meaning representation. Defaults to an average of the
|
||||
|
|
|
@ -292,6 +292,8 @@ cdef class Token:
|
|||
def __get__(self):
|
||||
if 'has_vector' in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks['has_vector'](self)
|
||||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||
return True
|
||||
return self.vocab.has_vector(self.c.lex.orth)
|
||||
|
||||
property vector:
|
||||
|
@ -303,7 +305,10 @@ cdef class Token:
|
|||
def __get__(self):
|
||||
if 'vector' in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks['vector'](self)
|
||||
return self.vocab.get_vector(self.c.lex.orth)
|
||||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||
return self.doc.tensor[self.i]
|
||||
else:
|
||||
return self.vocab.get_vector(self.c.lex.orth)
|
||||
|
||||
property vector_norm:
|
||||
"""The L2 norm of the token's vector representation.
|
||||
|
|
Loading…
Reference in New Issue
Block a user