mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Fix calculation of L2-norm for Lexeme
This commit is contained in:
parent
7638f439e5
commit
a0a4ada42a
|
@@ -4,6 +4,7 @@ from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
|
|||
from libc.string cimport memset
|
||||
from libc.stdint cimport int32_t
|
||||
from libc.stdint cimport uint64_t
|
||||
from libc.math cimport sqrt
|
||||
|
||||
import bz2
|
||||
from os import path
|
||||
|
@@ -386,6 +387,7 @@ cdef class Vocab:
|
|||
cdef LexemeC* lexeme
|
||||
cdef attr_t orth
|
||||
cdef int32_t vec_len = -1
|
||||
cdef double norm = 0.0
|
||||
for line_num, line in enumerate(file_):
|
||||
pieces = line.split()
|
||||
word_str = pieces.pop(0)
|
||||
|
@@ -397,9 +399,12 @@ cdef class Vocab:
|
|||
orth = self.strings[word_str]
|
||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||
lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
|
||||
|
||||
for i, val_str in enumerate(pieces):
|
||||
lexeme.vector[i] = float(val_str)
|
||||
norm = 0.0
|
||||
for i in range(vec_len):
|
||||
norm += lexeme.vector[i] * lexeme.vector[i]
|
||||
lex.l2_norm = sqrt(norm)
|
||||
self.vectors_length = vec_len
|
||||
return vec_len
|
||||
|
||||
|
@@ -438,14 +443,15 @@ cdef class Vocab:
|
|||
line_num += 1
|
||||
cdef LexemeC* lex
|
||||
cdef size_t lex_addr
|
||||
cdef double norm = 0.0
|
||||
cdef int i
|
||||
for orth, lex_addr in self._by_orth.items():
|
||||
lex = <LexemeC*>lex_addr
|
||||
if lex.lower < vectors.size():
|
||||
lex.vector = vectors[lex.lower]
|
||||
for i in range(vec_len):
|
||||
lex.l2_norm += (lex.vector[i] * lex.vector[i])
|
||||
lex.l2_norm = math.sqrt(lex.l2_norm)
|
||||
norm += lex.vector[i] * lex.vector[i]
|
||||
lex.l2_norm = sqrt(norm)
|
||||
else:
|
||||
lex.vector = EMPTY_VEC
|
||||
self.vectors_length = vec_len
|
||||
|
|
Loading…
Reference in New Issue
Block a user