diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index ac748292c..e569b46e7 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -4,16 +4,12 @@ from collections import OrderedDict import msgpack import msgpack_numpy msgpack_numpy.patch() -from cymem.cymem cimport Pool cimport numpy as np -from libcpp.vector cimport vector from .typedefs cimport attr_t from .strings cimport StringStore from . import util -from ._cfile cimport CFile - -MAX_VEC_SIZE = 10000 +from .compat import basestring_ cdef class Vectors: @@ -60,7 +56,21 @@ cdef class Vectors: yield from self.data def __len__(self): - return len(self.strings) + # TODO: Fix the quadratic behaviour here! + return max(self.key2row.values()) + + def __contains__(self, key): + if isinstance(key, basestring_): + key = self.strings[key] + return key in self.key2row + + def add_key(self, string, vector=None): + key = self.strings.add(string) + next_i = len(self) + 1 + self.keys[next_i] = key + self.key2row[key] = next_i + if vector is not None: + self.data[next_i] = vector def items(self): for i, string in enumerate(self.strings): @@ -75,9 +85,9 @@ cdef class Vectors: def to_disk(self, path, **exclude): serializers = OrderedDict(( - ('vectors', lambda p: numpy.save(p.open('wb'), self.data)), + ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), ('strings.json', self.strings.to_disk), - ('keys', lambda p: numpy.save(p.open('wb'), self.keys)), + ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), )) return util.to_disk(path, serializers, exclude) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 055f2ef24..1fc3f5e39 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -19,7 +19,7 @@ from .tokens.token cimport Token from .attrs cimport PROB, LANG from .structs cimport SerializedLexemeC -from .compat import copy_reg, pickle +from .compat import copy_reg, pickle, basestring_ from .lemmatizer import Lemmatizer from .attrs import intify_attrs from .vectors import Vectors @@ -244,7 +244,7 @@ cdef class Vocab: @property def vectors_length(self): - raise NotImplementedError + return len(self.vectors) def clear_vectors(self): """Drop the current vector table. Because all vectors must be the same @@ -264,7 +264,9 @@ cdef class Vocab: RAISES: If no vectors data is loaded, ValueError is raised. """ - raise NotImplementedError + if isinstance(orth, basestring_): + orth = self.strings.add(orth) + return self.vectors[orth] def set_vector(self, orth, vector): """Set a vector for a word in the vocabulary. @@ -274,13 +276,17 @@ cdef class Vocab: RETURNS: None """ - raise NotImplementedError + if not isinstance(orth, basestring_): + orth = self.strings[orth] + self.vectors.add_key(orth, vector=vector) def has_vector(self, orth): """Check whether a word has a vector. Returns False if no vectors have been loaded. Words can be looked up by string or int ID.""" - return False + if isinstance(orth, basestring_): + orth = self.strings.add(orth) + return orth in self.vectors def to_disk(self, path, **exclude): """Save the current state to a directory.