mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Fix retrieval of OOV vectors
This commit is contained in:
parent
df2745eb08
commit
83f8e98450
|
@ -4,6 +4,7 @@ from __future__ import unicode_literals
|
||||||
import bz2
|
import bz2
|
||||||
import ujson
|
import ujson
|
||||||
import re
|
import re
|
||||||
|
import numpy
|
||||||
|
|
||||||
from libc.string cimport memset, memcpy
|
from libc.string cimport memset, memcpy
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
|
@ -244,7 +245,7 @@ cdef class Vocab:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vectors_length(self):
|
def vectors_length(self):
|
||||||
return self.vectors.data.shape[0]
|
return self.vectors.data.shape[1]
|
||||||
|
|
||||||
def clear_vectors(self, new_dim=None):
|
def clear_vectors(self, new_dim=None):
|
||||||
"""Drop the current vector table. Because all vectors must be the same
|
"""Drop the current vector table. Because all vectors must be the same
|
||||||
|
@ -268,7 +269,10 @@ cdef class Vocab:
|
||||||
"""
|
"""
|
||||||
if isinstance(orth, basestring_):
|
if isinstance(orth, basestring_):
|
||||||
orth = self.strings.add(orth)
|
orth = self.strings.add(orth)
|
||||||
return self.vectors[orth]
|
if orth in self.vectors.key2row:
|
||||||
|
return self.vectors[orth]
|
||||||
|
else:
|
||||||
|
return numpy.zeros((self.vectors_length,), dtype='f')
|
||||||
|
|
||||||
def set_vector(self, orth, vector):
|
def set_vector(self, orth, vector):
|
||||||
"""Set a vector for a word in the vocabulary.
|
"""Set a vector for a word in the vocabulary.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user