mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Fix calculation of L2-norm for Lexeme
This commit is contained in:
		
							parent
							
								
									7638f439e5
								
							
						
					
					
						commit
						a0a4ada42a
					
				|  | @ -4,6 +4,7 @@ from libc.stdio cimport fopen, fclose, fread, fwrite, FILE | ||||||
| from libc.string cimport memset | from libc.string cimport memset | ||||||
| from libc.stdint cimport int32_t | from libc.stdint cimport int32_t | ||||||
| from libc.stdint cimport uint64_t | from libc.stdint cimport uint64_t | ||||||
|  | from libc.math cimport sqrt | ||||||
| 
 | 
 | ||||||
| import bz2 | import bz2 | ||||||
| from os import path | from os import path | ||||||
|  | @ -386,6 +387,7 @@ cdef class Vocab: | ||||||
|         cdef LexemeC* lexeme |         cdef LexemeC* lexeme | ||||||
|         cdef attr_t orth |         cdef attr_t orth | ||||||
|         cdef int32_t vec_len = -1 |         cdef int32_t vec_len = -1 | ||||||
|  |         cdef double norm = 0.0 | ||||||
|         for line_num, line in enumerate(file_): |         for line_num, line in enumerate(file_): | ||||||
|             pieces = line.split() |             pieces = line.split() | ||||||
|             word_str = pieces.pop(0) |             word_str = pieces.pop(0) | ||||||
|  | @ -397,9 +399,12 @@ cdef class Vocab: | ||||||
|             orth = self.strings[word_str] |             orth = self.strings[word_str] | ||||||
|             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth) |             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth) | ||||||
|             lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float)) |             lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float)) | ||||||
| 
 |  | ||||||
|             for i, val_str in enumerate(pieces): |             for i, val_str in enumerate(pieces): | ||||||
|                 lexeme.vector[i] = float(val_str) |                 lexeme.vector[i] = float(val_str) | ||||||
|  |             norm = 0.0 | ||||||
|  |             for i in range(vec_len): | ||||||
|  |                 norm += lexeme.vector[i] * lexeme.vector[i] | ||||||
|  |             lex.l2_norm = sqrt(norm) | ||||||
|         self.vectors_length = vec_len |         self.vectors_length = vec_len | ||||||
|         return vec_len |         return vec_len | ||||||
| 
 | 
 | ||||||
|  | @ -438,14 +443,15 @@ cdef class Vocab: | ||||||
|             line_num += 1 |             line_num += 1 | ||||||
|         cdef LexemeC* lex |         cdef LexemeC* lex | ||||||
|         cdef size_t lex_addr |         cdef size_t lex_addr | ||||||
|  |         cdef double norm = 0.0 | ||||||
|         cdef int i |         cdef int i | ||||||
|         for orth, lex_addr in self._by_orth.items(): |         for orth, lex_addr in self._by_orth.items(): | ||||||
|             lex = <LexemeC*>lex_addr |             lex = <LexemeC*>lex_addr | ||||||
|             if lex.lower < vectors.size(): |             if lex.lower < vectors.size(): | ||||||
|                 lex.vector = vectors[lex.lower] |                 lex.vector = vectors[lex.lower] | ||||||
|                 for i in range(vec_len): |                 for i in range(vec_len): | ||||||
|                     lex.l2_norm += (lex.vector[i] * lex.vector[i]) |                     norm += lex.vector[i] * lex.vector[i] | ||||||
|                 lex.l2_norm = math.sqrt(lex.l2_norm) |                 lex.l2_norm = sqrt(norm) | ||||||
|             else: |             else: | ||||||
|                 lex.vector = EMPTY_VEC |                 lex.vector = EMPTY_VEC | ||||||
|         self.vectors_length = vec_len |         self.vectors_length = vec_len | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user