mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Update vectors with find method
This commit is contained in:
parent
2ad2f09d12
commit
c16310d156
|
@ -10,7 +10,7 @@ cimport numpy as np
|
|||
from thinc.neural.util import get_array_module
|
||||
from thinc.neural._classes.model import Model
|
||||
|
||||
from .strings cimport StringStore
|
||||
from .strings cimport StringStore, hash_string
|
||||
from .compat import basestring_, path2str
|
||||
from . import util
|
||||
|
||||
|
@ -183,17 +183,42 @@ cdef class Vectors:
|
|||
for key, row in self.key2row.items():
|
||||
yield key, self.data[row]
|
||||
|
||||
def get_keys(self, rows):
|
||||
xp = get_array_module(self.data)
|
||||
row2key = {row: key for key, row in self.key2row.items()}
|
||||
keys = xp.asarray([row2key[row] for row in rows],
|
||||
dtype='uint64')
|
||||
return keys
|
||||
def find(self, *, key=None, keys=None, row=None, rows=None):
|
||||
'''Lookup one or more keys by row, or vice versa.
|
||||
|
||||
def get_rows(self, keys):
|
||||
key (unicode / int): Find the row that the given key points to.
|
||||
Returns int, -1 if missing.
|
||||
keys (sequence): Find rows that the keys point to.
|
||||
Returns ndarray.
|
||||
row (int): Find the first key that point to the row.
|
||||
Returns int.
|
||||
rows (sequence): Find the first keys that points to the rows.
|
||||
Returns ndarray.
|
||||
'''
|
||||
if sum(arg is None for arg in (key, keys, row, rows)) != 3:
|
||||
raise ValueError("One (and only one) keyword arg must be set.")
|
||||
xp = get_array_module(self.data)
|
||||
k2r = self.key2row
|
||||
return xp.asarray([k2r.get(key, -1) for key in keys], dtype='i')
|
||||
if key is not None:
|
||||
if isinstance(key, basestring_):
|
||||
key = hash_string(key)
|
||||
return self.key2row.get(key, -1)
|
||||
elif keys is not None:
|
||||
keys = [hash_string(key) if isinstance(key, basestring_) else key
|
||||
for key in keys]
|
||||
rows = [self.key2row.get(key, -1.) for key in keys]
|
||||
return xp.asarray(rows, dtype='i')
|
||||
else:
|
||||
targets = set()
|
||||
if row is not None:
|
||||
targets.add(row)
|
||||
else:
|
||||
targets.update(rows)
|
||||
results = []
|
||||
for key, row in self.key2row.items():
|
||||
if row in targets:
|
||||
results.append(key)
|
||||
targets.remove(row)
|
||||
return xp.asarray(results, dtype='uint64')
|
||||
|
||||
def add(self, key, *, vector=None, row=None):
|
||||
"""Add a key to the table. Keys can be mapped to an existing vector
|
||||
|
@ -204,6 +229,8 @@ cdef class Vectors:
|
|||
row (int / None): The row number of a vector to map the key to.
|
||||
RETURNS (int): The row the vector was added to.
|
||||
"""
|
||||
if isinstance(key, basestring):
|
||||
key = hash_string(key)
|
||||
if row is None and key in self.key2row:
|
||||
row = self.key2row[key]
|
||||
elif row is None:
|
||||
|
@ -248,7 +275,7 @@ cdef class Vectors:
|
|||
sims = xp.dot(batch, vectors.T)
|
||||
best_rows[i:i+batch_size] = sims.argmax(axis=1)
|
||||
scores[i:i+batch_size] = sims.max(axis=1)
|
||||
keys = self.get_keys(best_rows)
|
||||
keys = self.find(rows=best_rows)
|
||||
return (keys, best_rows, scores)
|
||||
|
||||
def from_glove(self, path):
|
||||
|
|
|
@ -286,14 +286,13 @@ cdef class Vocab:
|
|||
priority.sort()
|
||||
indices = xp.asarray([i for (prob, i, key) in priority], dtype='i')
|
||||
keys = xp.asarray([key for (prob, i, key) in priority], dtype='uint64')
|
||||
|
||||
|
||||
keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
|
||||
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
|
||||
|
||||
self.vectors = Vectors(data=keep, keys=keys)
|
||||
|
||||
syn_keys, syn_rows, scores = self.vectors.most_similar(toss,
|
||||
return_rows=True, return_scores=True)
|
||||
syn_keys, syn_rows, scores = self.vectors.most_similar(toss)
|
||||
|
||||
remap = {}
|
||||
for i, key in enumerate(keys[nr_row:]):
|
||||
|
|
Loading…
Reference in New Issue
Block a user