diff --git a/.github/CONTRIBUTOR_AGREEMENT.md b/.github/CONTRIBUTOR_AGREEMENT.md index a8c741ce1..a2bd09d65 100644 --- a/.github/CONTRIBUTOR_AGREEMENT.md +++ b/.github/CONTRIBUTOR_AGREEMENT.md @@ -91,16 +91,16 @@ mark both statements: or entity, including my employer, has or will have rights with respect to my contributions. - * [x] I am signing on behalf of my employer or a legal entity and I have the + * [] I am signing on behalf of my employer or a legal entity and I have the actual authority to contractually bind that entity. ## Contributor Details | Field | Entry | |------------------------------- | -------------------- | -| Name | | +| Name | Suraj Rajan | | Company name (if applicable) | | | Title or role (if applicable) | | -| Date | | -| GitHub username | | +| Date | 31/Mar/2018 | +| GitHub username | skrcode | | Website (optional) | | diff --git a/spacy/tests/vectors/test_vectors.py b/spacy/tests/vectors/test_vectors.py index a9eabc78d..ce32eec00 100644 --- a/spacy/tests/vectors/test_vectors.py +++ b/spacy/tests/vectors/test_vectors.py @@ -28,12 +28,38 @@ def vectors(): def data(): return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype='f') +@pytest.fixture +def resize_data(): + return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype='f') @pytest.fixture() def vocab(en_vocab, vectors): add_vecs_to_vocab(en_vocab, vectors) return en_vocab +def test_init_vectors_with_resize_shape(strings,resize_data): + v = Vectors(shape=(len(strings), 3)) + v.resize(shape=resize_data.shape) + assert v.shape == resize_data.shape + assert v.shape != (len(strings), 3) + +def test_init_vectors_with_resize_data(data,resize_data): + v = Vectors(data=data) + v.resize(shape=resize_data.shape) + assert v.shape == resize_data.shape + assert v.shape != data.shape + +def test_get_vector_resize(strings, data,resize_data): + v = Vectors(data=data) + v.resize(shape=resize_data.shape) + strings = [hash_string(s) for s in strings] + for i, string in enumerate(strings): + v.add(string, row=i) + + assert list(v[strings[0]]) == list(resize_data[0]) + assert list(v[strings[0]]) != list(resize_data[1]) + assert list(v[strings[1]]) != list(resize_data[0]) + assert list(v[strings[1]]) == list(resize_data[1]) def test_init_vectors_with_data(strings, data): v = Vectors(data=data) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 33d817fb2..25dedea5f 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -16,6 +16,8 @@ from .strings cimport StringStore, hash_string from .compat import basestring_, path2str from . import util +from cython.operator cimport dereference as deref +from libcpp.set cimport set as cppset def unpickle_vectors(bytes_data): return Vectors().from_bytes(bytes_data) @@ -50,7 +52,7 @@ cdef class Vectors: cdef public object name cdef public object data cdef public object key2row - cdef public object _unset + cdef cppset[int] _unset def __init__(self, *, shape=None, data=None, keys=None, name=None): """Create a new vector store. @@ -69,9 +71,9 @@ cdef class Vectors: self.data = data self.key2row = OrderedDict() if self.data is not None: - self._unset = set(range(self.data.shape[0])) + self._unset = cppset[int]({i for i in range(self.data.shape[0])}) else: - self._unset = set() + self._unset = cppset[int]() if keys is not None: for i, key in enumerate(keys): self.add(key, row=i) @@ -93,7 +95,7 @@ cdef class Vectors: @property def is_full(self): """RETURNS (bool): `True` if no slots are available for new keys.""" - return len(self._unset) == 0 + return self._unset.size() == 0 @property def n_keys(self): @@ -124,8 +126,8 @@ cdef class Vectors: """ i = self.key2row[key] self.data[i] = vector - if i in self._unset: - self._unset.remove(i) + if self._unset.count(i): + self._unset.erase(self._unset.find(i)) def __iter__(self): """Iterate over the keys in the table. @@ -164,7 +166,7 @@ cdef class Vectors: xp = get_array_module(self.data) self.data = xp.resize(self.data, shape) filled = {row for row in self.key2row.values()} - self._unset = {row for row in range(shape[0]) if row not in filled} + self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled}) removed_items = [] for key, row in list(self.key2row.items()): if row >= shape[0]: @@ -188,7 +190,7 @@ cdef class Vectors: YIELDS (ndarray): A vector in the table. """ for row, vector in enumerate(range(self.data.shape[0])): - if row not in self._unset: + if not self._unset.count(row): yield vector def items(self): @@ -253,13 +255,13 @@ cdef class Vectors: elif row is None: if self.is_full: raise ValueError("Cannot add new key to vectors -- full") - row = min(self._unset) + row = deref(self._unset.begin()) self.key2row[key] = row if vector is not None: self.data[row] = vector - if row in self._unset: - self._unset.remove(row) + if self._unset.count(row): + self._unset.erase(self._unset.find(row)) return row def most_similar(self, queries, *, batch_size=1024): @@ -365,8 +367,8 @@ cdef class Vectors: with path.open('rb') as file_: self.key2row = msgpack.load(file_) for key, row in self.key2row.items(): - if row in self._unset: - self._unset.remove(row) + if self._unset.count(row): + self._unset.erase(self._unset.find(row)) def load_keys(path): if path.exists():