mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
WIP on refactoring and fixing vectors
This commit is contained in:
parent
256c7dac5a
commit
368fdb389a
16
spacy/_ml.py
16
spacy/_ml.py
|
@ -29,6 +29,12 @@ from . import util
|
|||
VECTORS_KEY = 'spacy_pretrained_vectors'
|
||||
|
||||
|
||||
def cosine(vec1, vec2):
|
||||
norm1 = (vec1**2).sum() ** 0.5
|
||||
norm2 = (vec2**2).sum() ** 0.5
|
||||
return vec1.dot(vec2) / (norm1 * norm2)
|
||||
|
||||
|
||||
@layerize
|
||||
def _flatten_add_lengths(seqs, pad=0, drop=0.):
|
||||
ops = Model.ops
|
||||
|
@ -198,11 +204,11 @@ class PrecomputableAffine(Model):
|
|||
def link_vectors_to_models(vocab):
|
||||
vectors = vocab.vectors
|
||||
ops = Model.ops
|
||||
for word in vocab:
|
||||
if word.orth in vectors.key2row:
|
||||
word.rank = vectors.key2row[word.orth]
|
||||
else:
|
||||
word.rank = 0
|
||||
#for word in vocab:
|
||||
# if word.orth in vectors.key2row:
|
||||
# word.rank = vectors.key2row[word.orth]
|
||||
# else:
|
||||
# word.rank = 0
|
||||
data = ops.asarray(vectors.data)
|
||||
# Set an entry here, so that vectors are accessed by StaticVectors
|
||||
# (unideal, I know)
|
||||
|
|
|
@ -94,6 +94,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
nlp.meta.update(meta)
|
||||
if vectors:
|
||||
util.load_model(vectors, vocab=nlp.vocab)
|
||||
if vectors_limit is not None:
|
||||
remap = nlp.vocab.prune_vectors(vectors_limit)
|
||||
print('remap', len(remap))
|
||||
for key, (value, sim) in remap.items():
|
||||
print(repr(key), repr(value), sim)
|
||||
for name in pipeline:
|
||||
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
||||
|
|
|
@ -3,13 +3,41 @@ from __future__ import unicode_literals
|
|||
|
||||
import numpy
|
||||
import pytest
|
||||
from ...vocab import Vocab
|
||||
from ..._ml import cosine
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize('text', ["Hello"])
|
||||
def test_vocab_add_vector(en_vocab, text):
|
||||
en_vocab.resize_vectors(10)
|
||||
lex = en_vocab[text]
|
||||
lex.vector = numpy.ndarray((10,), dtype='float32')
|
||||
lex = en_vocab[text]
|
||||
assert lex.vector.shape == (10,)
|
||||
def test_vocab_add_vector():
|
||||
vocab = Vocab()
|
||||
data = numpy.ndarray((5,3), dtype='f')
|
||||
data[0] = 1.
|
||||
data[1] = 2.
|
||||
vocab.set_vector(u'cat', data[0])
|
||||
vocab.set_vector(u'dog', data[1])
|
||||
cat = vocab[u'cat']
|
||||
assert list(cat.vector) == [1., 1., 1.]
|
||||
dog = vocab[u'dog']
|
||||
assert list(dog.vector) == [2., 2., 2.]
|
||||
for lex in vocab:
|
||||
print(lex.orth_)
|
||||
|
||||
|
||||
def test_vocab_prune_vectors():
|
||||
vocab = Vocab()
|
||||
_ = vocab[u'cat']
|
||||
_ = vocab[u'dog']
|
||||
_ = vocab[u'kitten']
|
||||
print(list(vocab.strings))
|
||||
data = numpy.ndarray((5,3), dtype='f')
|
||||
data[0] = 1.
|
||||
data[1] = 2.
|
||||
data[2] = 1.1
|
||||
vocab.set_vector(u'cat', data[0])
|
||||
vocab.set_vector(u'dog', data[1])
|
||||
vocab.set_vector(u'kitten', data[2])
|
||||
for lex in vocab:
|
||||
print(lex.orth_)
|
||||
|
||||
remap = vocab.prune_vectors(2)
|
||||
assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
|
||||
#print(remap)
|
||||
|
|
|
@ -27,8 +27,7 @@ cdef class Vectors:
|
|||
cdef public object data
|
||||
cdef readonly StringStore strings
|
||||
cdef public object key2row
|
||||
cdef public object keys
|
||||
cdef public int i
|
||||
cdef public int _i_vec
|
||||
|
||||
def __init__(self, strings, width=0, data=None):
|
||||
"""Create a new vector store. To keep the vector table empty, pass
|
||||
|
@ -51,13 +50,13 @@ cdef class Vectors:
|
|||
self.data = numpy.asarray(data, dtype='f')
|
||||
else:
|
||||
self.data = numpy.zeros((len(self.strings), width), dtype='f')
|
||||
self.i = 0
|
||||
self._i_vec = 0
|
||||
self.key2row = {}
|
||||
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
|
||||
if data is not None:
|
||||
for i, string in enumerate(self.strings):
|
||||
if i >= self.data.shape[0]:
|
||||
break
|
||||
self.add(self.strings[string], self.data[i])
|
||||
self.add(self.strings[string], vector=self.data[i])
|
||||
|
||||
def __reduce__(self):
|
||||
return (Vectors, (self.strings, self.data))
|
||||
|
@ -122,16 +121,15 @@ cdef class Vectors:
|
|||
"""
|
||||
if isinstance(key, basestring_):
|
||||
key = self.strings.add(key)
|
||||
if key not in self.key2row:
|
||||
i = self.i
|
||||
if i >= self.keys.shape[0]:
|
||||
self.keys.resize((self.keys.shape[0]*2,))
|
||||
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
|
||||
self.key2row[key] = self.i
|
||||
self.keys[self.i] = key
|
||||
self.i += 1
|
||||
else:
|
||||
i = self.key2row[key]
|
||||
if row is None and key in self.key2row:
|
||||
row = self.key2row[key]
|
||||
elif row is None:
|
||||
row = self._i_vec
|
||||
self._i_vec += 1
|
||||
if row >= self.data.shape[0]:
|
||||
self.data.resize((row*2, self.data.shape[1]))
|
||||
|
||||
self.key2row[key] = row
|
||||
if vector is not None:
|
||||
self.data[i] = vector
|
||||
return i
|
||||
|
@ -141,9 +139,9 @@ cdef class Vectors:
|
|||
|
||||
YIELDS (tuple): A key/vector pair.
|
||||
"""
|
||||
for i, key in enumerate(self.keys):
|
||||
for key, row in self.key2row.items():
|
||||
string = self.strings[key]
|
||||
yield string, self.data[i]
|
||||
yield string, self.data[row]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
@ -202,7 +200,7 @@ cdef class Vectors:
|
|||
save_array = lambda arr, file_: xp.save(file_, arr)
|
||||
serializers = OrderedDict((
|
||||
('vectors', lambda p: save_array(self.data, p.open('wb'))),
|
||||
('keys', lambda p: xp.save(p.open('wb'), self.keys))
|
||||
('key2row', lambda p: msgpack.dump(self.key2row, p.open('wb')))
|
||||
))
|
||||
return util.to_disk(path, serializers, exclude)
|
||||
|
||||
|
@ -215,10 +213,7 @@ cdef class Vectors:
|
|||
"""
|
||||
def load_keys(path):
|
||||
if path.exists():
|
||||
self.keys = numpy.load(path2str(path))
|
||||
for i, key in enumerate(self.keys):
|
||||
self.keys[i] = key
|
||||
self.key2row[key] = i
|
||||
self.key2row = msgpack.load(path.open('rb'))
|
||||
|
||||
def load_vectors(path):
|
||||
xp = Model.ops.xp
|
||||
|
@ -226,7 +221,7 @@ cdef class Vectors:
|
|||
self.data = xp.load(path)
|
||||
|
||||
serializers = OrderedDict((
|
||||
('keys', load_keys),
|
||||
('key2row', load_keys),
|
||||
('vectors', load_vectors),
|
||||
))
|
||||
util.from_disk(path, serializers, exclude)
|
||||
|
@ -244,7 +239,7 @@ cdef class Vectors:
|
|||
else:
|
||||
return msgpack.dumps(self.data)
|
||||
serializers = OrderedDict((
|
||||
('keys', lambda: msgpack.dumps(self.keys)),
|
||||
('key2row', lambda: msgpack.dumps(self.key2row)),
|
||||
('vectors', serialize_weights)
|
||||
))
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
@ -262,14 +257,8 @@ cdef class Vectors:
|
|||
else:
|
||||
self.data = msgpack.loads(b)
|
||||
|
||||
def load_keys(keys):
|
||||
self.keys.resize((len(keys),))
|
||||
for i, key in enumerate(keys):
|
||||
self.keys[i] = key
|
||||
self.key2row[key] = i
|
||||
|
||||
deserializers = OrderedDict((
|
||||
('keys', lambda b: load_keys(msgpack.loads(b))),
|
||||
('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
|
||||
('vectors', deserialize_weights)
|
||||
))
|
||||
util.from_bytes(data, deserializers, exclude)
|
||||
|
|
|
@ -190,10 +190,11 @@ cdef class Vocab:
|
|||
|
||||
YIELDS (Lexeme): An entry in the vocabulary.
|
||||
"""
|
||||
cdef attr_t orth
|
||||
cdef attr_t key
|
||||
cdef size_t addr
|
||||
for orth, addr in self._by_orth.items():
|
||||
yield Lexeme(self, orth)
|
||||
for key, addr in self._by_orth.items():
|
||||
lex = Lexeme(self, key)
|
||||
yield lex
|
||||
|
||||
def __getitem__(self, id_or_string):
|
||||
"""Retrieve a lexeme, given an int ID or a unicode string. If a
|
||||
|
@ -211,7 +212,7 @@ cdef class Vocab:
|
|||
>>> assert nlp.vocab[apple] == nlp.vocab[u'apple']
|
||||
"""
|
||||
cdef attr_t orth
|
||||
if type(id_or_string) == unicode:
|
||||
if isinstance(id_or_string, unicode):
|
||||
orth = self.strings.add(id_or_string)
|
||||
else:
|
||||
orth = id_or_string
|
||||
|
@ -242,9 +243,69 @@ cdef class Vocab:
|
|||
"""Drop the current vector table. Because all vectors must be the same
|
||||
width, you have to call this to change the size of the vectors.
|
||||
"""
|
||||
if new_dim is None:
|
||||
new_dim = self.vectors.data.shape[1]
|
||||
self.vectors = Vectors(self.strings, width=new_dim)
|
||||
if width is None:
|
||||
width = self.vectors.data.shape[1]
|
||||
self.vectors = Vectors(self.strings, width=width)
|
||||
|
||||
def prune_vectors(self, nr_row, batch_size=8):
|
||||
"""Reduce the current vector table to `nr_row` unique entries. Words
|
||||
mapped to the discarded vectors will be remapped to the closest vector
|
||||
among those remaining.
|
||||
|
||||
For example, suppose the original table had vectors for the words:
|
||||
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to,
|
||||
two rows, we would discard the vectors for 'feline' and 'reclined'.
|
||||
These words would then be remapped to the closest remaining vector
|
||||
-- so "feline" would have the same vector as "cat", and "reclined"
|
||||
would have the same vector as "sat".
|
||||
|
||||
The similarities are judged by cosine. The original vectors may
|
||||
be large, so the cosines are calculated in minibatches, to reduce
|
||||
memory usage.
|
||||
|
||||
nr_row (int): The number of rows to keep in the vector table.
|
||||
batch_size (int): Batch of vectors for calculating the similarities.
|
||||
Larger batch sizes might be faster, while temporarily requiring
|
||||
more memory.
|
||||
RETURNS (dict): A dictionary keyed by removed words mapped to
|
||||
`(string, score)` tuples, where `string` is the entry the removed
|
||||
word was mapped to, and `score` the similarity score between the
|
||||
two words.
|
||||
"""
|
||||
xp = get_array_module(self.vectors.data)
|
||||
# Work in batches, to avoid memory problems.
|
||||
keep = self.vectors.data[:nr_row]
|
||||
keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
|
||||
toss = self.vectors.data[nr_row:]
|
||||
# Normalize the vectors, so cosine similarity is just dot product.
|
||||
# Note we can't modify the ones we're keeping in-place...
|
||||
keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
|
||||
keep = xp.ascontiguousarray(keep.T)
|
||||
neighbours = xp.zeros((toss.shape[0],), dtype='i')
|
||||
scores = xp.zeros((toss.shape[0],), dtype='f')
|
||||
for i in range(0, toss.shape[0]//2, batch_size):
|
||||
batch = toss[i : i+batch_size]
|
||||
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
|
||||
sims = xp.dot(batch, keep)
|
||||
matches = sims.argmax(axis=1)
|
||||
neighbours[i:i+batch_size] = matches
|
||||
scores[i:i+batch_size] = sims.max(axis=1)
|
||||
i2k = {i: key for key, i in self.vectors.key2row.items()}
|
||||
remap = {}
|
||||
for lex in list(self):
|
||||
# If we're losing the vector for this word, map it to the nearest
|
||||
# vector we're keeping.
|
||||
if lex.rank >= nr_row:
|
||||
lex.rank = neighbours[lex.rank-nr_row]
|
||||
self.vectors.add(lex.orth, row=lex.rank)
|
||||
remap[lex.orth_] = (i2k[lex.rank], scores[lex.rank])
|
||||
for key, row in self.vectors.key2row.items():
|
||||
if row >= nr_row:
|
||||
self.vectors.key2row[key] = neighbours[row-nr_row]
|
||||
# Make copy, to encourage the original table to be garbage collected.
|
||||
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
|
||||
link_vectors_to_models(self)
|
||||
return remap
|
||||
|
||||
def get_vector(self, orth):
|
||||
"""Retrieve a vector for a word in the vocabulary. Words can be looked
|
||||
|
@ -266,9 +327,11 @@ cdef class Vocab:
|
|||
"""Set a vector for a word in the vocabulary. Words can be referenced
|
||||
by string or int ID.
|
||||
"""
|
||||
if not isinstance(orth, basestring_):
|
||||
orth = self.strings[orth]
|
||||
if self.vectors.data.size == 0:
|
||||
self.clear_vectors(vector.shape[0])
|
||||
lex = self[orth]
|
||||
self.vectors.add(orth, vector=vector)
|
||||
lex.rank = self.vectors.key2row[lex.orth]
|
||||
|
||||
def has_vector(self, orth):
|
||||
"""Check whether a word has a vector. Returns False if no vectors have
|
||||
|
|
Loading…
Reference in New Issue
Block a user