From a0a4ada42ae5c744e33f5c8ae34b50830e5e7435 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 23 Oct 2016 14:44:45 +0200 Subject: [PATCH] Fix calculation of L2-norm for Lexeme --- spacy/vocab.pyx | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index f852e67da..9af55790a 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -4,6 +4,7 @@ from libc.stdio cimport fopen, fclose, fread, fwrite, FILE from libc.string cimport memset from libc.stdint cimport int32_t from libc.stdint cimport uint64_t +from libc.math cimport sqrt import bz2 from os import path @@ -386,6 +387,7 @@ cdef class Vocab: cdef LexemeC* lexeme cdef attr_t orth cdef int32_t vec_len = -1 + cdef double norm = 0.0 for line_num, line in enumerate(file_): pieces = line.split() word_str = pieces.pop(0) @@ -397,9 +399,12 @@ cdef class Vocab: orth = self.strings[word_str] lexeme = self.get_by_orth(self.mem, orth) lexeme.vector = self.mem.alloc(vec_len, sizeof(float)) - for i, val_str in enumerate(pieces): lexeme.vector[i] = float(val_str) + norm = 0.0 + for i in range(vec_len): + norm += lexeme.vector[i] * lexeme.vector[i] + lex.l2_norm = sqrt(norm) self.vectors_length = vec_len return vec_len @@ -438,14 +443,15 @@ cdef class Vocab: line_num += 1 cdef LexemeC* lex cdef size_t lex_addr + cdef double norm = 0.0 cdef int i for orth, lex_addr in self._by_orth.items(): lex = lex_addr if lex.lower < vectors.size(): lex.vector = vectors[lex.lower] for i in range(vec_len): - lex.l2_norm += (lex.vector[i] * lex.vector[i]) - lex.l2_norm = math.sqrt(lex.l2_norm) + norm += lex.vector[i] * lex.vector[i] + lex.l2_norm = sqrt(norm) else: lex.vector = EMPTY_VEC self.vectors_length = vec_len