Fix retrieval of OOV vectors

This commit is contained in:
Matthew Honnibal 2017-08-22 19:46:35 +02:00
parent df2745eb08
commit 83f8e98450

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals
import bz2 import bz2
import ujson import ujson
import re import re
import numpy
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
@ -244,7 +245,7 @@ cdef class Vocab:
@property @property
def vectors_length(self): def vectors_length(self):
return self.vectors.data.shape[0] return self.vectors.data.shape[1]
def clear_vectors(self, new_dim=None): def clear_vectors(self, new_dim=None):
"""Drop the current vector table. Because all vectors must be the same """Drop the current vector table. Because all vectors must be the same
@ -268,7 +269,10 @@ cdef class Vocab:
""" """
if isinstance(orth, basestring_): if isinstance(orth, basestring_):
orth = self.strings.add(orth) orth = self.strings.add(orth)
return self.vectors[orth] if orth in self.vectors.key2row:
return self.vectors[orth]
else:
return numpy.zeros((self.vectors_length,), dtype='f')
def set_vector(self, orth, vector): def set_vector(self, orth, vector):
"""Set a vector for a word in the vocabulary. """Set a vector for a word in the vocabulary.