Update vectors with find method

This commit is contained in:
Matthew Honnibal 2017-11-01 00:34:55 +01:00
parent 2ad2f09d12
commit c16310d156
2 changed files with 40 additions and 14 deletions

View File

@ -10,7 +10,7 @@ cimport numpy as np
from thinc.neural.util import get_array_module
from thinc.neural._classes.model import Model
from .strings cimport StringStore
from .strings cimport StringStore, hash_string
from .compat import basestring_, path2str
from . import util
@ -183,17 +183,42 @@ cdef class Vectors:
for key, row in self.key2row.items():
yield key, self.data[row]
def get_keys(self, rows):
    """Look up the key stored for each of the given rows.

    rows (sequence): Row indices to look up.
    RETURNS (ndarray): The keys as a 'uint64' array, in the same order
        as `rows`.
    RAISES (KeyError): If a row has no key pointing to it.
    """
    # Invert key2row. When several keys share a row, the key visited last
    # during iteration is the one that ends up in the table.
    key_for_row = {}
    for key, row in self.key2row.items():
        key_for_row[row] = key
    xp = get_array_module(self.data)
    return xp.asarray([key_for_row[r] for r in rows], dtype='uint64')
def find(self, *, key=None, keys=None, row=None, rows=None):
    """Look up one or more rows by key, or one or more keys by row.

    Exactly one of the keyword arguments must be given.

    key (unicode / int): Find the row that the given key points to.
        Returns int, -1 if missing.
    keys (sequence): Find the rows that the keys point to.
        Returns ndarray of dtype 'i', with -1 for missing keys.
    row (int): Find the first key that points to the row.
        Returns int.
    rows (sequence): Find the first key that points to each of the rows.
        Returns ndarray of dtype 'uint64', aligned with `rows` (same
        length and order, duplicates preserved).
    RAISES (ValueError): If none, or more than one, of the arguments is set.
    RAISES (KeyError): If `row`, or an entry of `rows`, has no key
        pointing to it.
    """
    if sum(arg is None for arg in (key, keys, row, rows)) != 3:
        raise ValueError("One (and only one) keyword arg must be set.")
    if key is not None:
        if isinstance(key, basestring_):
            key = hash_string(key)
        return self.key2row.get(key, -1)
    if keys is not None:
        xp = get_array_module(self.data)
        hashed = [hash_string(k) if isinstance(k, basestring_) else k
                  for k in keys]
        return xp.asarray([self.key2row.get(k, -1) for k in hashed],
                          dtype='i')
    # Reverse lookup. Build a row -> key table once; setdefault keeps the
    # first key encountered for a row, matching the documented contract.
    row2key = {}
    for k, r in self.key2row.items():
        row2key.setdefault(r, k)
    # int() normalises array scalars so they can be used as dict keys.
    if row is not None:
        return row2key[int(row)]
    # Index the table per requested row (rather than scanning key2row) so
    # the output stays aligned with `rows`, which callers that return keys
    # next to rows and scores rely on. A uint64 array has no room for a
    # "missing" sentinel, so a row without a key raises KeyError.
    xp = get_array_module(self.data)
    return xp.asarray([row2key[int(r)] for r in rows], dtype='uint64')

def get_rows(self, keys):
    """Find the rows that the given keys point to.

    Keys are looked up as given; strings are not hashed here.

    keys (sequence): The keys to look up.
    RETURNS (ndarray): Row indices as an 'i' array, -1 for missing keys.
    """
    xp = get_array_module(self.data)
    k2r = self.key2row
    return xp.asarray([k2r.get(key, -1) for key in keys], dtype='i')
def add(self, key, *, vector=None, row=None):
"""Add a key to the table. Keys can be mapped to an existing vector
@ -204,6 +229,8 @@ cdef class Vectors:
row (int / None): The row number of a vector to map the key to.
RETURNS (int): The row the vector was added to.
"""
if isinstance(key, basestring):
key = hash_string(key)
if row is None and key in self.key2row:
row = self.key2row[key]
elif row is None:
@ -248,7 +275,7 @@ cdef class Vectors:
sims = xp.dot(batch, vectors.T)
best_rows[i:i+batch_size] = sims.argmax(axis=1)
scores[i:i+batch_size] = sims.max(axis=1)
keys = self.get_keys(best_rows)
keys = self.find(rows=best_rows)
return (keys, best_rows, scores)
def from_glove(self, path):

View File

@ -286,14 +286,13 @@ cdef class Vocab:
priority.sort()
indices = xp.asarray([i for (prob, i, key) in priority], dtype='i')
keys = xp.asarray([key for (prob, i, key) in priority], dtype='uint64')
keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
self.vectors = Vectors(data=keep, keys=keys)
syn_keys, syn_rows, scores = self.vectors.most_similar(toss,
return_rows=True, return_scores=True)
syn_keys, syn_rows, scores = self.vectors.most_similar(toss)
remap = {}
for i, key in enumerate(keys[nr_row:]):