mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Try using tensor for vector/similarity methdos
This commit is contained in:
parent
a131981f3b
commit
498ad85309
|
@ -30,6 +30,7 @@ from ..syntax.iterators import CHUNKERS
|
|||
from ..util import normalize_slice
|
||||
from ..compat import is_config
|
||||
from .. import about
|
||||
from .. import util
|
||||
|
||||
|
||||
DEF PADDING = 5
|
||||
|
@ -252,8 +253,12 @@ cdef class Doc:
|
|||
def __get__(self):
|
||||
if 'has_vector' in self.user_hooks:
|
||||
return self.user_hooks['has_vector'](self)
|
||||
|
||||
return any(token.has_vector for token in self)
|
||||
elif any(token.has_vector for token in self):
|
||||
return True
|
||||
elif self.tensor:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
property vector:
|
||||
"""A real-valued meaning representation. Defaults to an average of the
|
||||
|
@ -265,12 +270,16 @@ cdef class Doc:
|
|||
def __get__(self):
|
||||
if 'vector' in self.user_hooks:
|
||||
return self.user_hooks['vector'](self)
|
||||
if self._vector is None:
|
||||
if len(self):
|
||||
self._vector = sum(t.vector for t in self) / len(self)
|
||||
else:
|
||||
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
|
||||
return self._vector
|
||||
if self._vector is not None:
|
||||
return self._vector
|
||||
elif self.has_vector and len(self):
|
||||
self._vector = sum(t.vector for t in self) / len(self)
|
||||
return self._vector
|
||||
elif self.tensor:
|
||||
self._vector = self.tensor.mean(axis=0)
|
||||
return self._vector
|
||||
else:
|
||||
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
|
||||
|
||||
def __set__(self, value):
|
||||
self._vector = value
|
||||
|
@ -295,10 +304,6 @@ cdef class Doc:
|
|||
def __set__(self, value):
|
||||
self._vector_norm = value
|
||||
|
||||
@property
|
||||
def string(self):
|
||||
return self.text
|
||||
|
||||
property text:
|
||||
"""A unicode representation of the document text.
|
||||
|
||||
|
@ -598,15 +603,16 @@ cdef class Doc:
|
|||
self.is_tagged = bool(TAG in attrs or POS in attrs)
|
||||
return self
|
||||
|
||||
def to_disk(self, path):
|
||||
def to_disk(self, path, **exclude):
|
||||
"""Save the current state to a directory.
|
||||
|
||||
path (unicode or Path): A path to a directory, which will be created if
|
||||
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
with path.open('wb') as file_:
|
||||
file_.write(self.to_bytes(**exclude))
|
||||
|
||||
def from_disk(self, path):
|
||||
def from_disk(self, path, **exclude):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it.
|
||||
|
||||
|
@ -614,25 +620,28 @@ cdef class Doc:
|
|||
strings or `Path`-like objects.
|
||||
RETURNS (Doc): The modified `Doc` object.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
with path.open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
self.from_bytes(bytes_data, **exclude)
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self, **exclude):
|
||||
"""Serialize, i.e. export the document contents to a binary string.
|
||||
|
||||
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
|
||||
all annotations.
|
||||
"""
|
||||
return dill.dumps(
|
||||
(self.text,
|
||||
self.to_array([LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]),
|
||||
self.sentiment,
|
||||
self.tensor,
|
||||
self.noun_chunks_iterator,
|
||||
self.user_data,
|
||||
(self.user_hooks, self.user_token_hooks, self.user_span_hooks)),
|
||||
protocol=-1)
|
||||
array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
|
||||
serializers = {
|
||||
'text': lambda: self.text,
|
||||
'array_head': lambda: array_head,
|
||||
'array_body': lambda: self.to_array(array_head),
|
||||
'sentiment': lambda: self.sentiment,
|
||||
'tensor': lambda: self.tensor,
|
||||
'user_data': lambda: self.user_data
|
||||
}
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, data):
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
"""Deserialize, i.e. import the document contents from a binary string.
|
||||
|
||||
data (bytes): The string to load from.
|
||||
|
@ -640,27 +649,36 @@ cdef class Doc:
|
|||
"""
|
||||
if self.length != 0:
|
||||
raise ValueError("Cannot load into non-empty Doc")
|
||||
deserializers = {
|
||||
'text': lambda b: None,
|
||||
'array_head': lambda b: None,
|
||||
'array_body': lambda b: None,
|
||||
'sentiment': lambda b: None,
|
||||
'tensor': lambda b: None,
|
||||
'user_data': lambda user_data: self.user_data.update(user_data)
|
||||
}
|
||||
|
||||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
|
||||
cdef attr_t[:, :] attrs
|
||||
cdef int i, start, end, has_space
|
||||
fields = dill.loads(data)
|
||||
text, attrs = fields[:2]
|
||||
self.sentiment, self.tensor = fields[2:4]
|
||||
self.noun_chunks_iterator, self.user_data = fields[4:6]
|
||||
self.user_hooks, self.user_token_hooks, self.user_span_hooks = fields[6]
|
||||
self.sentiment = msg['sentiment']
|
||||
self.tensor = msg['tensor']
|
||||
|
||||
start = 0
|
||||
cdef const LexemeC* lex
|
||||
cdef unicode orth_
|
||||
text = msg['text']
|
||||
attrs = msg['array_body']
|
||||
for i in range(attrs.shape[0]):
|
||||
end = start + attrs[i, 0]
|
||||
has_space = attrs[i, 1]
|
||||
orth_ = text[start:end]
|
||||
lex = self.vocab.get(self.mem, orth_)
|
||||
self.push_back(lex, has_space)
|
||||
|
||||
start = end + has_space
|
||||
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE],
|
||||
attrs[:, 2:])
|
||||
self.from_array(msg['array_head'][2:],
|
||||
attrs[:, 2:])
|
||||
return self
|
||||
|
||||
def merge(self, int start_idx, int end_idx, *args, **attributes):
|
||||
|
|
|
@ -111,7 +111,7 @@ cdef class Token:
|
|||
RETURNS (float): A scalar similarity score. Higher is more similar.
|
||||
"""
|
||||
if 'similarity' in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks['similarity'](self)
|
||||
return self.doc.user_token_hooks['similarity'](self)
|
||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||
return 0.0
|
||||
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
|
@ -245,7 +245,10 @@ cdef class Token:
|
|||
def __get__(self):
|
||||
if 'vector' in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks['vector'](self)
|
||||
return self.vocab.get_vector(self.c.lex.orth)
|
||||
if self.has_vector:
|
||||
return self.vocab.get_vector(self.c.lex.orth)
|
||||
else:
|
||||
return self.doc.tensor[self.i]
|
||||
|
||||
property vector_norm:
|
||||
"""The L2 norm of the token's vector representation.
|
||||
|
|
Loading…
Reference in New Issue
Block a user