Fix vector loading

Matthew Honnibal 2017-08-19 19:52:25 +02:00
parent 49a615e7d9
commit 93fb8b64e9
2 changed files with 29 additions and 13 deletions


@@ -4,16 +4,12 @@ from collections import OrderedDict
 import msgpack
 import msgpack_numpy
 msgpack_numpy.patch()
-from cymem.cymem cimport Pool
 cimport numpy as np
-from libcpp.vector cimport vector
 from .typedefs cimport attr_t
 from .strings cimport StringStore
 from . import util
-from ._cfile cimport CFile
+from .compat import basestring_
-MAX_VEC_SIZE = 10000
 
 cdef class Vectors:
@@ -60,7 +56,21 @@ cdef class Vectors:
         yield from self.data
 
     def __len__(self):
-        return len(self.strings)
+        # TODO: Fix the quadratic behaviour here!
+        return max(self.key2row.values())
+
+    def __contains__(self, key):
+        if isinstance(key, basestring_):
+            key = self.strings[key]
+        return key in self.key2row
+
+    def add_key(self, string, vector=None):
+        key = self.strings.add(string)
+        next_i = len(self) + 1
+        self.keys[next_i] = key
+        self.key2row[key] = next_i
+        if vector is not None:
+            self.data[next_i] = vector
 
     def items(self):
         for i, string in enumerate(self.strings):
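
The __contains__ and add_key methods added above key the vector table through the StringStore: add_key interns the string, records its integer key in the keys array, and maps that key to a table row in key2row, while __len__ reports the highest row in use. Because add_key calls len(self) on every insertion, loading n vectors does O(n) work per key, which is the quadratic behaviour the TODO flags. A minimal pure-Python sketch of that bookkeeping, with plain dicts standing in for the StringStore and the numpy table and the separate keys array left out (ToyVectors and everything inside it are illustrative, not spaCy's API):

    class ToyVectors(object):
        """Toy stand-in for the diff's keying scheme: string -> key -> row."""
        def __init__(self):
            self.strings = {}    # string -> integer key (StringStore stand-in)
            self.key2row = {}    # integer key -> row of the vector table
            self.data = {}       # row -> vector (numpy table stand-in)

        def __len__(self):
            # Mirrors the diff: length is the highest row in use
            # (an empty-table guard is added for the sketch).
            return max(self.key2row.values()) if self.key2row else 0

        def __contains__(self, key):
            if isinstance(key, str):
                key = self.strings.get(key)
            return key in self.key2row

        def add_key(self, string, vector=None):
            key = self.strings.setdefault(string, len(self.strings) + 1)
            next_i = len(self) + 1
            self.key2row[key] = next_i
            if vector is not None:
                self.data[next_i] = vector

    vectors = ToyVectors()
    vectors.add_key(u'apple', vector=[0.1, 0.2, 0.3])
    assert u'apple' in vectors and len(vectors) == 1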
@@ -75,9 +85,9 @@ cdef class Vectors:
     def to_disk(self, path, **exclude):
         serializers = OrderedDict((
-            ('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
+            ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
             ('strings.json', self.strings.to_disk),
-            ('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
+            ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
         ))
         return util.to_disk(path, serializers, exclude)
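
Both numpy.save calls above now pass allow_pickle=False, which keeps the saved .npy files free of pickled Python objects: a plain numeric array saves as before, while an array that would need pickling is rejected outright. A small standalone numpy sketch of that behaviour, independent of spaCy (the temporary path is only for illustration):

    import os
    import tempfile
    import numpy

    path = os.path.join(tempfile.mkdtemp(), 'vectors.npy')
    with open(path, 'wb') as f:
        # Plain float32 data, like the vector table: saves fine without pickle.
        numpy.save(f, numpy.zeros((3, 5), dtype='float32'), allow_pickle=False)

    try:
        with open(path, 'wb') as f:
            # An object array would need pickle, so numpy raises ValueError.
            numpy.save(f, numpy.array([{'oops': 1}], dtype=object), allow_pickle=False)
    except ValueError:
        pass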


@@ -19,7 +19,7 @@ from .tokens.token cimport Token
 from .attrs cimport PROB, LANG
 from .structs cimport SerializedLexemeC
-from .compat import copy_reg, pickle
+from .compat import copy_reg, pickle, basestring_
 from .lemmatizer import Lemmatizer
 from .attrs import intify_attrs
 from .vectors import Vectors
@@ -244,7 +244,7 @@ cdef class Vocab:
     @property
     def vectors_length(self):
-        raise NotImplementedError
+        return len(self.vectors)
 
     def clear_vectors(self):
         """Drop the current vector table. Because all vectors must be the same
@@ -264,7 +264,9 @@ cdef class Vocab:
         RAISES: If no vectors data is loaded, ValueError is raised.
         """
-        raise NotImplementedError
+        if isinstance(orth, basestring_):
+            orth = self.strings.add(orth)
+        return self.vectors[orth]
 
     def set_vector(self, orth, vector):
         """Set a vector for a word in the vocabulary.
@@ -274,13 +276,17 @@ cdef class Vocab:
         RETURNS:
             None
         """
-        raise NotImplementedError
+        if not isinstance(orth, basestring_):
+            orth = self.strings[orth]
+        self.vectors.add_key(orth, vector=vector)
 
     def has_vector(self, orth):
         """Check whether a word has a vector. Returns False if no
         vectors have been loaded. Words can be looked up by string
         or int ID."""
-        return False
+        if isinstance(orth, basestring_):
+            orth = self.strings.add(orth)
+        return orth in self.vectors
 
     def to_disk(self, path, **exclude):
         """Save the current state to a directory.