[2032] - Changed python set to cpp stl set (#2170)

Changed python set to cpp stl set #2032 

## Description

Changed python set to cpp stl set. CPP stl set works better due to the logarithmic run time of its methods. Finding minimum in the cpp set is done in constant time as opposed to the worst case linear runtime of python set. Operations such as find,count,insert,delete are also done in either constant and logarithmic time thus making cpp set a better option to manage vectors.
Reference : http://www.cplusplus.com/reference/set/set/

### Types of change
Enhancement for `Vectors` for faster initialising of word vectors(fasttext)
This commit is contained in:
Suraj Rajan 2018-03-31 16:58:25 +05:30 committed by Matthew Honnibal
parent 6f84e32253
commit 1cdbb7c97c
3 changed files with 45 additions and 17 deletions

View File

@ -91,16 +91,16 @@ mark both statements:
or entity, including my employer, has or will have rights with respect to my or entity, including my employer, has or will have rights with respect to my
contributions. contributions.
* [x] I am signing on behalf of my employer or a legal entity and I have the * [] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity. actual authority to contractually bind that entity.
## Contributor Details ## Contributor Details
| Field | Entry | | Field | Entry |
|------------------------------- | -------------------- | |------------------------------- | -------------------- |
| Name | | | Name | Suraj Rajan |
| Company name (if applicable) | | | Company name (if applicable) | |
| Title or role (if applicable) | | | Title or role (if applicable) | |
| Date | | | Date | 31/Mar/2018 |
| GitHub username | | | GitHub username | skrcode |
| Website (optional) | | | Website (optional) | |

View File

@ -28,12 +28,38 @@ def vectors():
def data(): def data():
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype='f') return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype='f')
@pytest.fixture
def resize_data():
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype='f')
@pytest.fixture() @pytest.fixture()
def vocab(en_vocab, vectors): def vocab(en_vocab, vectors):
add_vecs_to_vocab(en_vocab, vectors) add_vecs_to_vocab(en_vocab, vectors)
return en_vocab return en_vocab
def test_init_vectors_with_resize_shape(strings,resize_data):
v = Vectors(shape=(len(strings), 3))
v.resize(shape=resize_data.shape)
assert v.shape == resize_data.shape
assert v.shape != (len(strings), 3)
def test_init_vectors_with_resize_data(data,resize_data):
v = Vectors(data=data)
v.resize(shape=resize_data.shape)
assert v.shape == resize_data.shape
assert v.shape != data.shape
def test_get_vector_resize(strings, data,resize_data):
v = Vectors(data=data)
v.resize(shape=resize_data.shape)
strings = [hash_string(s) for s in strings]
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(resize_data[0])
assert list(v[strings[0]]) != list(resize_data[1])
assert list(v[strings[1]]) != list(resize_data[0])
assert list(v[strings[1]]) == list(resize_data[1])
def test_init_vectors_with_data(strings, data): def test_init_vectors_with_data(strings, data):
v = Vectors(data=data) v = Vectors(data=data)

View File

@ -16,6 +16,8 @@ from .strings cimport StringStore, hash_string
from .compat import basestring_, path2str from .compat import basestring_, path2str
from . import util from . import util
from cython.operator cimport dereference as deref
from libcpp.set cimport set as cppset
def unpickle_vectors(bytes_data): def unpickle_vectors(bytes_data):
return Vectors().from_bytes(bytes_data) return Vectors().from_bytes(bytes_data)
@ -50,7 +52,7 @@ cdef class Vectors:
cdef public object name cdef public object name
cdef public object data cdef public object data
cdef public object key2row cdef public object key2row
cdef public object _unset cdef cppset[int] _unset
def __init__(self, *, shape=None, data=None, keys=None, name=None): def __init__(self, *, shape=None, data=None, keys=None, name=None):
"""Create a new vector store. """Create a new vector store.
@ -69,9 +71,9 @@ cdef class Vectors:
self.data = data self.data = data
self.key2row = OrderedDict() self.key2row = OrderedDict()
if self.data is not None: if self.data is not None:
self._unset = set(range(self.data.shape[0])) self._unset = cppset[int]({i for i in range(self.data.shape[0])})
else: else:
self._unset = set() self._unset = cppset[int]()
if keys is not None: if keys is not None:
for i, key in enumerate(keys): for i, key in enumerate(keys):
self.add(key, row=i) self.add(key, row=i)
@ -93,7 +95,7 @@ cdef class Vectors:
@property @property
def is_full(self): def is_full(self):
"""RETURNS (bool): `True` if no slots are available for new keys.""" """RETURNS (bool): `True` if no slots are available for new keys."""
return len(self._unset) == 0 return self._unset.size() == 0
@property @property
def n_keys(self): def n_keys(self):
@ -124,8 +126,8 @@ cdef class Vectors:
""" """
i = self.key2row[key] i = self.key2row[key]
self.data[i] = vector self.data[i] = vector
if i in self._unset: if self._unset.count(i):
self._unset.remove(i) self._unset.erase(self._unset.find(i))
def __iter__(self): def __iter__(self):
"""Iterate over the keys in the table. """Iterate over the keys in the table.
@ -164,7 +166,7 @@ cdef class Vectors:
xp = get_array_module(self.data) xp = get_array_module(self.data)
self.data = xp.resize(self.data, shape) self.data = xp.resize(self.data, shape)
filled = {row for row in self.key2row.values()} filled = {row for row in self.key2row.values()}
self._unset = {row for row in range(shape[0]) if row not in filled} self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
removed_items = [] removed_items = []
for key, row in list(self.key2row.items()): for key, row in list(self.key2row.items()):
if row >= shape[0]: if row >= shape[0]:
@ -188,7 +190,7 @@ cdef class Vectors:
YIELDS (ndarray): A vector in the table. YIELDS (ndarray): A vector in the table.
""" """
for row, vector in enumerate(range(self.data.shape[0])): for row, vector in enumerate(range(self.data.shape[0])):
if row not in self._unset: if not self._unset.count(row):
yield vector yield vector
def items(self): def items(self):
@ -253,13 +255,13 @@ cdef class Vectors:
elif row is None: elif row is None:
if self.is_full: if self.is_full:
raise ValueError("Cannot add new key to vectors -- full") raise ValueError("Cannot add new key to vectors -- full")
row = min(self._unset) row = deref(self._unset.begin())
self.key2row[key] = row self.key2row[key] = row
if vector is not None: if vector is not None:
self.data[row] = vector self.data[row] = vector
if row in self._unset: if self._unset.count(row):
self._unset.remove(row) self._unset.erase(self._unset.find(row))
return row return row
def most_similar(self, queries, *, batch_size=1024): def most_similar(self, queries, *, batch_size=1024):
@ -365,8 +367,8 @@ cdef class Vectors:
with path.open('rb') as file_: with path.open('rb') as file_:
self.key2row = msgpack.load(file_) self.key2row = msgpack.load(file_)
for key, row in self.key2row.items(): for key, row in self.key2row.items():
if row in self._unset: if self._unset.count(row):
self._unset.remove(row) self._unset.erase(self._unset.find(row))
def load_keys(path): def load_keys(path):
if path.exists(): if path.exists():