Fix calculation of vector norm, re Issue #522. Need to consolidate the calculations into a helper function.

This commit is contained in:
Matthew Honnibal 2016-10-23 14:49:31 +02:00
parent a0a4ada42a
commit 2c3a67b693
2 changed files with 10 additions and 8 deletions

View File

@ -1,12 +1,12 @@
cimport cython cimport cython
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from libc.math cimport sqrt
import numpy import numpy
import numpy.linalg import numpy.linalg
import struct import struct
cimport numpy as np cimport numpy as np
import math
import six import six
import warnings import warnings
@ -251,11 +251,12 @@ cdef class Doc:
if 'vector_norm' in self.user_hooks: if 'vector_norm' in self.user_hooks:
return self.user_hooks['vector_norm'](self) return self.user_hooks['vector_norm'](self)
cdef float value cdef float value
cdef double norm = 0
if self._vector_norm is None: if self._vector_norm is None:
self._vector_norm = 1e-20 norm = 0.0
for value in self.vector: for value in self.vector:
self._vector_norm += value * value norm += value * value
self._vector_norm = math.sqrt(self._vector_norm) self._vector_norm = sqrt(norm) if norm != 0 else 0
return self._vector_norm return self._vector_norm
def __set__(self, value): def __set__(self, value):

View File

@ -3,7 +3,7 @@ from collections import defaultdict
import numpy import numpy
import numpy.linalg import numpy.linalg
cimport numpy as np cimport numpy as np
import math from libc.math cimport sqrt
import six import six
from ..structs cimport TokenC, LexemeC from ..structs cimport TokenC, LexemeC
@ -136,11 +136,12 @@ cdef class Span:
if 'vector_norm' in self.doc.user_span_hooks: if 'vector_norm' in self.doc.user_span_hooks:
return self.doc.user_span_hooks['vector'](self) return self.doc.user_span_hooks['vector'](self)
cdef float value cdef float value
cdef double norm = 0
if self._vector_norm is None: if self._vector_norm is None:
self._vector_norm = 1e-20 norm = 0
for value in self.vector: for value in self.vector:
self._vector_norm += value * value norm += value * value
self._vector_norm = math.sqrt(self._vector_norm) self._vector_norm = sqrt(norm) if norm != 0 else 0
return self._vector_norm return self._vector_norm
property text: property text: