* Try giving Doc and Span objects vector and vector_norm attributes, and .similarity functions. Turns out to be bad idea.

This commit is contained in:
Matthew Honnibal 2015-09-17 11:50:11 +10:00
parent 191d593e03
commit 77856c4fcd
4 changed files with 48 additions and 8 deletions

View File

@ -23,6 +23,9 @@ cdef class Doc:
cdef readonly Pool mem cdef readonly Pool mem
cdef readonly Vocab vocab cdef readonly Vocab vocab
cdef public object _vector
cdef public object _vector_norm
cdef TokenC* data cdef TokenC* data
cdef public bint is_tagged cdef public bint is_tagged

View File

@ -6,6 +6,7 @@ import numpy
import numpy.linalg import numpy.linalg
import struct import struct
cimport numpy as np cimport numpy as np
import math
from ..lexeme cimport Lexeme from ..lexeme cimport Lexeme
from ..lexeme cimport EMPTY_LEXEME from ..lexeme cimport EMPTY_LEXEME
@ -77,6 +78,7 @@ cdef class Doc:
self.is_tagged = False self.is_tagged = False
self.is_parsed = False self.is_parsed = False
self._py_tokens = [] self._py_tokens = []
self._vector = None
def __getitem__(self, object i): def __getitem__(self, object i):
"""Get a token. """Get a token.
@ -133,12 +135,25 @@ cdef class Doc:
property vector: property vector:
def __get__(self): def __get__(self):
return sum(t.vector for t in self if not t.is_stop) / len(self) if self._vector is None:
self._vector = sum(t.vector for t in self) / len(self)
return self._vector
def __set__(self, value):
self._vector = value
property vector_norm: property vector_norm:
def __get__(self): def __get__(self):
return numpy.linalg.norm(self.vector) cdef float value
if self._vector_norm is None:
self._vector_norm = 1e-20
for value in self.vector:
self._vector_norm += value * value
self._vector_norm = math.sqrt(self._vector_norm)
return self._vector_norm
def __set__(self, value):
self._vector_norm = value
@property @property
def string(self): def string(self):
@ -304,15 +319,14 @@ cdef class Doc:
cdef size_t count cdef size_t count
if counts is None: if counts is None:
counts = PreshCounter(self.length) counts = PreshCounter()
output_dict = True output_dict = True
else: else:
output_dict = False output_dict = False
# Take this check out of the loop, for a bit of extra speed # Take this check out of the loop, for a bit of extra speed
if exclude is None: if exclude is None:
for i in range(self.length): for i in range(self.length):
attr = get_token_attr(&self.data[i], attr_id) counts.inc(get_token_attr(&self.data[i], attr_id), 1)
counts.inc(attr, 1)
else: else:
for i in range(self.length): for i in range(self.length):
if not exclude(self[i]): if not exclude(self[i]):

View File

@ -7,3 +7,6 @@ cdef class Span:
cdef public int start cdef public int start
cdef public int end cdef public int end
cdef readonly int label cdef readonly int label
cdef public _vector
cdef public _vector_norm

View File

@ -3,6 +3,7 @@ from collections import defaultdict
import numpy import numpy
import numpy.linalg import numpy.linalg
cimport numpy as np cimport numpy as np
import math
from ..structs cimport TokenC, LexemeC from ..structs cimport TokenC, LexemeC
from ..typedefs cimport flags_t, attr_t from ..typedefs cimport flags_t, attr_t
@ -21,6 +22,8 @@ cdef class Span:
self.start = start self.start = start
self.end = end self.end = end
self.label = label self.label = label
self._vector = None
self._vector_norm = None
def __richcmp__(self, Span other, int op): def __richcmp__(self, Span other, int op):
# Eq # Eq
@ -60,15 +63,32 @@ cdef class Span:
property vector: property vector:
def __get__(self): def __get__(self):
return sum(t.vector for t in self if not t.is_stop) / len(self) if self._vector is None:
self._vector = sum(t.vector for t in self) / len(self)
return self._vector
def __set__(self, value):
self._vector = value
property vector_norm: property vector_norm:
def __get__(self): def __get__(self):
return numpy.linalg.norm(self.vector) cdef float value
if self._vector_norm is None:
self._vector_norm = 1e-20
for value in self.vector:
self._vector_norm += value * value
self._vector_norm = math.sqrt(self._vector_norm)
return self._vector_norm
def __set__(self, value):
self._vector_norm = value
property text: property text:
def __get__(self): def __get__(self):
return u' '.join([t.text for t in self]) text = self.text_with_ws
if self[-1].whitespace_:
text = text[:-1]
return text
property text_with_ws: property text_with_ws:
def __get__(self): def __get__(self):