Back-off to tensor for similarity if no vectors

This commit is contained in:
Matthew Honnibal 2017-11-03 20:56:33 +01:00
parent 1e9634691a
commit 144a93c2a5
4 changed files with 23 additions and 5 deletions

View File

@ -75,3 +75,11 @@ def test_en_models_probs(example):
assert not prob0 == prob1 assert not prob0 == prob1
assert not prob0 == prob2 assert not prob0 == prob2
assert not prob1 == prob2 assert not prob1 == prob2
@pytest.mark.models('en')
def test_no_vectors_similarity(EN):
doc1 = EN(u'hallo')
doc2 = EN(u'hi')
assert doc1.similarity(doc2) > 0

View File

@ -307,7 +307,7 @@ cdef class Doc:
def __get__(self): def __get__(self):
if 'has_vector' in self.user_hooks: if 'has_vector' in self.user_hooks:
return self.user_hooks['has_vector'](self) return self.user_hooks['has_vector'](self)
elif any(token.has_vector for token in self): elif self.vocab.vectors.data.size:
return True return True
elif self.tensor.size: elif self.tensor.size:
return True return True
@ -330,13 +330,13 @@ cdef class Doc:
self._vector = numpy.zeros((self.vocab.vectors_length,), self._vector = numpy.zeros((self.vocab.vectors_length,),
dtype='f') dtype='f')
return self._vector return self._vector
elif self.has_vector: elif self.vocab.vectors.data.size > 0:
vector = numpy.zeros((self.vocab.vectors_length,), dtype='f') vector = numpy.zeros((self.vocab.vectors_length,), dtype='f')
for token in self.c[:self.length]: for token in self.c[:self.length]:
vector += self.vocab.get_vector(token.lex.orth) vector += self.vocab.get_vector(token.lex.orth)
self._vector = vector / len(self) self._vector = vector / len(self)
return self._vector return self._vector
elif self.tensor.size: elif self.tensor.size > 0:
self._vector = self.tensor.mean(axis=0) self._vector = self.tensor.mean(axis=0)
return self._vector return self._vector
else: else:

View File

@ -283,7 +283,12 @@ cdef class Span:
def __get__(self): def __get__(self):
if 'has_vector' in self.doc.user_span_hooks: if 'has_vector' in self.doc.user_span_hooks:
return self.doc.user_span_hooks['has_vector'](self) return self.doc.user_span_hooks['has_vector'](self)
elif self.vocab.vectors.data.size > 0:
return any(token.has_vector for token in self) return any(token.has_vector for token in self)
elif self.doc.tensor.size > 0:
return True
else:
return False
property vector: property vector:
"""A real-valued meaning representation. Defaults to an average of the """A real-valued meaning representation. Defaults to an average of the

View File

@ -292,6 +292,8 @@ cdef class Token:
def __get__(self): def __get__(self):
if 'has_vector' in self.doc.user_token_hooks: if 'has_vector' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['has_vector'](self) return self.doc.user_token_hooks['has_vector'](self)
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return True
return self.vocab.has_vector(self.c.lex.orth) return self.vocab.has_vector(self.c.lex.orth)
property vector: property vector:
@ -303,6 +305,9 @@ cdef class Token:
def __get__(self): def __get__(self):
if 'vector' in self.doc.user_token_hooks: if 'vector' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['vector'](self) return self.doc.user_token_hooks['vector'](self)
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return self.doc.tensor[self.i]
else:
return self.vocab.get_vector(self.c.lex.orth) return self.vocab.get_vector(self.c.lex.orth)
property vector_norm: property vector_norm: