Fix retrieval of OOV vectors

This commit is contained in:
Matthew Honnibal 2017-08-22 19:46:35 +02:00
parent df2745eb08
commit 83f8e98450

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals
import bz2
import ujson
import re
import numpy
from libc.string cimport memset, memcpy
from libc.stdint cimport int32_t
@ -244,7 +245,7 @@ cdef class Vocab:
@property
def vectors_length(self):
return self.vectors.data.shape[0]
return self.vectors.data.shape[1]
def clear_vectors(self, new_dim=None):
"""Drop the current vector table. Because all vectors must be the same
@ -268,7 +269,10 @@ cdef class Vocab:
"""
if isinstance(orth, basestring_):
orth = self.strings.add(orth)
return self.vectors[orth]
if orth in self.vectors.key2row:
return self.vectors[orth]
else:
return numpy.zeros((self.vectors_length,), dtype='f')
def set_vector(self, orth, vector):
"""Set a vector for a word in the vocabulary.