Fix vector loading

This commit is contained in:
Matthew Honnibal 2017-08-19 19:52:25 +02:00
parent 49a615e7d9
commit 93fb8b64e9
2 changed files with 29 additions and 13 deletions

View File

@ -4,16 +4,12 @@ from collections import OrderedDict
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
from cymem.cymem cimport Pool
cimport numpy as np
from libcpp.vector cimport vector
from .typedefs cimport attr_t
from .strings cimport StringStore
from . import util
from ._cfile cimport CFile
MAX_VEC_SIZE = 10000
from .compat import basestring_
cdef class Vectors:
@ -60,7 +56,21 @@ cdef class Vectors:
yield from self.data
def __len__(self):
return len(self.strings)
# TODO: Fix the quadratic behaviour here!
return max(self.key2row.values())
def __contains__(self, key):
if isinstance(key, basestring_):
key = self.strings[key]
return key in self.key2row
def add_key(self, string, vector=None):
key = self.strings.add(string)
next_i = len(self) + 1
self.keys[next_i] = key
self.key2row[key] = next_i
if vector is not None:
self.data[next_i] = vector
def items(self):
for i, string in enumerate(self.strings):
@ -75,9 +85,9 @@ cdef class Vectors:
def to_disk(self, path, **exclude):
serializers = OrderedDict((
('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
('strings.json', self.strings.to_disk),
('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
))
return util.to_disk(path, serializers, exclude)

View File

@ -19,7 +19,7 @@ from .tokens.token cimport Token
from .attrs cimport PROB, LANG
from .structs cimport SerializedLexemeC
from .compat import copy_reg, pickle
from .compat import copy_reg, pickle, basestring_
from .lemmatizer import Lemmatizer
from .attrs import intify_attrs
from .vectors import Vectors
@ -244,7 +244,7 @@ cdef class Vocab:
@property
def vectors_length(self):
raise NotImplementedError
return len(self.vectors)
def clear_vectors(self):
"""Drop the current vector table. Because all vectors must be the same
@ -264,7 +264,9 @@ cdef class Vocab:
RAISES: If no vectors data is loaded, ValueError is raised.
"""
raise NotImplementedError
if isinstance(orth, basestring_):
orth = self.strings.add(orth)
return self.vectors[orth]
def set_vector(self, orth, vector):
"""Set a vector for a word in the vocabulary.
@ -274,13 +276,17 @@ cdef class Vocab:
RETURNS:
None
"""
raise NotImplementedError
if not isinstance(orth, basestring_):
orth = self.strings[orth]
self.vectors.add_key(orth, vector=vector)
def has_vector(self, orth):
"""Check whether a word has a vector. Returns False if no
vectors have been loaded. Words can be looked up by string
or int ID."""
return False
if isinstance(orth, basestring_):
orth = self.strings.add(orth)
return orth in self.vectors
def to_disk(self, path, **exclude):
"""Save the current state to a directory.