Mirror of https://github.com/explosion/spaCy.git
WIP on refactoring and fixing vectors

commit 368fdb389a
parent 256c7dac5a

 spacy/_ml.py | 16
@@ -29,6 +29,12 @@ from . import util
 VECTORS_KEY = 'spacy_pretrained_vectors'
 
 
+def cosine(vec1, vec2):
+    norm1 = (vec1**2).sum() ** 0.5
+    norm2 = (vec2**2).sum() ** 0.5
+    return vec1.dot(vec2) / (norm1 * norm2)
+
+
 @layerize
 def _flatten_add_lengths(seqs, pad=0, drop=0.):
     ops = Model.ops
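As a quick sanity check of the new `cosine` helper, here is a standalone numpy snippet (illustration only, not part of the commit):

    import numpy

    def cosine(vec1, vec2):
        norm1 = (vec1**2).sum() ** 0.5
        norm2 = (vec2**2).sum() ** 0.5
        return vec1.dot(vec2) / (norm1 * norm2)

    a = numpy.asarray([1., 1., 1.])
    b = numpy.asarray([1.1, 1.1, 1.1])
    print(cosine(a, b))   # ~1.0: parallel vectors
    print(cosine(a, -a))  # -1.0: opposite vectors

Note the helper has no epsilon guard, so an all-zero vector yields nan; that matches the diff as written.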
@@ -198,11 +204,11 @@ class PrecomputableAffine(Model):
 def link_vectors_to_models(vocab):
     vectors = vocab.vectors
     ops = Model.ops
-    for word in vocab:
-        if word.orth in vectors.key2row:
-            word.rank = vectors.key2row[word.orth]
-        else:
-            word.rank = 0
+    #for word in vocab:
+    #    if word.orth in vectors.key2row:
+    #        word.rank = vectors.key2row[word.orth]
+    #    else:
+    #        word.rank = 0
     data = ops.asarray(vectors.data)
     # Set an entry here, so that vectors are accessed by StaticVectors
     # (unideal, I know)
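For context, `word.rank` serves as the word's row index into the vector table, with rank 0 as the fallback. The sync loop is commented out here, presumably because `Vocab.prune_vectors` and `Vocab.set_vector` (below) now maintain `lex.rank` themselves. A toy sketch of the mapping the loop established (hypothetical standalone code, not the spaCy API):

    import numpy

    key2row = {'cat': 0, 'dog': 1}        # stands in for vectors.key2row
    table = numpy.zeros((2, 3), dtype='f')
    table[key2row['cat']] = 1.

    rank = key2row.get('cat', 0)          # unknown words fall back to row 0
    print(table[rank])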
@@ -94,6 +94,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     nlp.meta.update(meta)
     if vectors:
         util.load_model(vectors, vocab=nlp.vocab)
+        if vectors_limit is not None:
+            remap = nlp.vocab.prune_vectors(vectors_limit)
+            print('remap', len(remap))
+            for key, (value, sim) in remap.items():
+                print(repr(key), repr(value), sim)
     for name in pipeline:
         nlp.add_pipe(nlp.create_pipe(name), name=name)
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
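For reference, `prune_vectors` returns a dict keyed by each removed word, mapping to the surviving entry it was merged into and the similarity between the two, so the debug loop above prints lines like this (values purely illustrative):

    remap = {'feline': ('cat', 0.93), 'reclined': ('sat', 0.88)}
    for key, (value, sim) in remap.items():
        print(repr(key), repr(value), sim)
    # 'feline' 'cat' 0.93
    # 'reclined' 'sat' 0.88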
@@ -3,13 +3,41 @@ from __future__ import unicode_literals
 
 import numpy
 import pytest
+from ...vocab import Vocab
+from ..._ml import cosine
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize('text', ["Hello"])
-def test_vocab_add_vector(en_vocab, text):
-    en_vocab.resize_vectors(10)
-    lex = en_vocab[text]
-    lex.vector = numpy.ndarray((10,), dtype='float32')
-    lex = en_vocab[text]
-    assert lex.vector.shape == (10,)
+def test_vocab_add_vector():
+    vocab = Vocab()
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    cat = vocab[u'cat']
+    assert list(cat.vector) == [1., 1., 1.]
+    dog = vocab[u'dog']
+    assert list(dog.vector) == [2., 2., 2.]
+    for lex in vocab:
+        print(lex.orth_)
+
+
+def test_vocab_prune_vectors():
+    vocab = Vocab()
+    _ = vocab[u'cat']
+    _ = vocab[u'dog']
+    _ = vocab[u'kitten']
+    print(list(vocab.strings))
+    data = numpy.ndarray((5,3), dtype='f')
+    data[0] = 1.
+    data[1] = 2.
+    data[2] = 1.1
+    vocab.set_vector(u'cat', data[0])
+    vocab.set_vector(u'dog', data[1])
+    vocab.set_vector(u'kitten', data[2])
+    for lex in vocab:
+        print(lex.orth_)
+
+    remap = vocab.prune_vectors(2)
+    assert remap == {u'kitten': (u'cat', cosine(data[0], data[2]))}
+    #print(remap)
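Why the assertion in `test_vocab_prune_vectors` can expect `kitten` to map to `cat`: all three test vectors are scalar multiples of [1, 1, 1], so every pairwise cosine is 1.0, and the argmax over the kept rows breaks the tie in favour of row 0, which is 'cat'. A standalone check:

    import numpy

    def cosine(vec1, vec2):
        norm1 = (vec1**2).sum() ** 0.5
        norm2 = (vec2**2).sum() ** 0.5
        return vec1.dot(vec2) / (norm1 * norm2)

    data = numpy.ndarray((5, 3), dtype='f')
    data[0] = 1.    # cat
    data[1] = 2.    # dog
    data[2] = 1.1   # kitten

    print(cosine(data[0], data[2]))  # ~1.0, so remap[u'kitten'] == (u'cat', 1.0)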
@@ -27,8 +27,7 @@ cdef class Vectors:
     cdef public object data
     cdef readonly StringStore strings
     cdef public object key2row
-    cdef public object keys
-    cdef public int i
+    cdef public int _i_vec
 
     def __init__(self, strings, width=0, data=None):
         """Create a new vector store. To keep the vector table empty, pass
@@ -51,13 +50,13 @@ cdef class Vectors:
             self.data = numpy.asarray(data, dtype='f')
         else:
             self.data = numpy.zeros((len(self.strings), width), dtype='f')
-        self.i = 0
+        self._i_vec = 0
         self.key2row = {}
-        self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
-        for i, string in enumerate(self.strings):
-            if i >= self.data.shape[0]:
-                break
-            self.add(self.strings[string], self.data[i])
+        if data is not None:
+            for i, string in enumerate(self.strings):
+                if i >= self.data.shape[0]:
+                    break
+                self.add(self.strings[string], vector=self.data[i])
 
     def __reduce__(self):
         return (Vectors, (self.strings, self.data))
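With the `keys` array gone, a table constructed from `data` now registers rows through `add(..., vector=...)`. Roughly, against this commit's API (assuming insertion-ordered iteration of the StringStore):

    import numpy
    from spacy.strings import StringStore
    from spacy.vectors import Vectors

    strings = StringStore([u'cat', u'dog'])
    data = numpy.zeros((2, 3), dtype='f')
    data[0] = 1.
    data[1] = 2.

    vectors = Vectors(strings, data=data)
    # each string was linked to the next free row by add()
    assert vectors.key2row[strings[u'cat']] == 0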
@@ -122,16 +121,15 @@ cdef class Vectors:
         """
         if isinstance(key, basestring_):
            key = self.strings.add(key)
-        if key not in self.key2row:
-            i = self.i
-            if i >= self.keys.shape[0]:
-                self.keys.resize((self.keys.shape[0]*2,))
-                self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
-            self.key2row[key] = self.i
-            self.keys[self.i] = key
-            self.i += 1
-        else:
-            i = self.key2row[key]
+        if row is None and key in self.key2row:
+            row = self.key2row[key]
+        elif row is None:
+            row = self._i_vec
+            self._i_vec += 1
+            if row >= self.data.shape[0]:
+                self.data.resize((row*2, self.data.shape[1]))
+
+        self.key2row[key] = row
         if vector is not None:
             self.data[i] = vector
         return i
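Under the new scheme, `add` can take an explicit `row`, reuse an existing mapping, or allocate the next free row via the internal `_i_vec` counter. A rough sketch of the three paths (hypothetical standalone code mirroring the diff, not the spaCy API itself):

    key2row = {}
    next_row = 0  # plays the role of Vectors._i_vec

    def add(key, row=None):
        global next_row
        if row is None and key in key2row:
            row = key2row[key]        # reuse the existing row
        elif row is None:
            row = next_row            # allocate the next free row
            next_row += 1
        key2row[key] = row            # an explicit row simply overwrites
        return row

    print(add('cat'))            # 0
    print(add('dog'))            # 1
    print(add('cat'))            # 0 again: reused
    print(add('kitten', row=0))  # alias 'kitten' onto row 0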
@@ -141,9 +139,9 @@ cdef class Vectors:
 
         YIELDS (tuple): A key/vector pair.
         """
-        for i, key in enumerate(self.keys):
+        for key, row in self.key2row.items():
             string = self.strings[key]
-            yield string, self.data[i]
+            yield string, self.data[row]
 
     @property
     def shape(self):
@@ -202,7 +200,7 @@ cdef class Vectors:
         save_array = lambda arr, file_: xp.save(file_, arr)
         serializers = OrderedDict((
             ('vectors', lambda p: save_array(self.data, p.open('wb'))),
-            ('keys', lambda p: xp.save(p.open('wb'), self.keys))
+            ('key2row', lambda p: msgpack.dump(self.key2row, p.open('wb')))
         ))
         return util.to_disk(path, serializers, exclude)
@@ -215,10 +213,7 @@ cdef class Vectors:
         """
         def load_keys(path):
             if path.exists():
-                self.keys = numpy.load(path2str(path))
-                for i, key in enumerate(self.keys):
-                    self.keys[i] = key
-                    self.key2row[key] = i
+                self.key2row = msgpack.load(path.open('rb'))
 
         def load_vectors(path):
             xp = Model.ops.xp
@@ -226,7 +221,7 @@ cdef class Vectors:
                 self.data = xp.load(path)
 
         serializers = OrderedDict((
-            ('keys', load_keys),
+            ('key2row', load_keys),
             ('vectors', load_vectors),
         ))
         util.from_disk(path, serializers, exclude)
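The on-disk `keys` numpy array is thus replaced by a msgpack-serialized `key2row` dict. A minimal round-trip of that format (assumes the `msgpack` package, which the diff already uses; the file name is made up):

    import msgpack

    key2row = {1234: 0, 5678: 1}  # hash key -> row index

    with open('key2row.msg', 'wb') as f:
        msgpack.dump(key2row, f)
    with open('key2row.msg', 'rb') as f:
        assert msgpack.load(f) == key2row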
@@ -244,7 +239,7 @@ cdef class Vectors:
         else:
             return msgpack.dumps(self.data)
         serializers = OrderedDict((
-            ('keys', lambda: msgpack.dumps(self.keys)),
+            ('key2row', lambda: msgpack.dumps(self.key2row)),
             ('vectors', serialize_weights)
         ))
         return util.to_bytes(serializers, exclude)
@@ -262,14 +257,8 @@ cdef class Vectors:
         else:
             self.data = msgpack.loads(b)
 
-        def load_keys(keys):
-            self.keys.resize((len(keys),))
-            for i, key in enumerate(keys):
-                self.keys[i] = key
-                self.key2row[key] = i
-
         deserializers = OrderedDict((
-            ('keys', lambda b: load_keys(msgpack.loads(b))),
+            ('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
             ('vectors', deserialize_weights)
         ))
         util.from_bytes(data, deserializers, exclude)
@@ -190,10 +190,11 @@ cdef class Vocab:
 
         YIELDS (Lexeme): An entry in the vocabulary.
         """
-        cdef attr_t orth
+        cdef attr_t key
         cdef size_t addr
-        for orth, addr in self._by_orth.items():
-            yield Lexeme(self, orth)
+        for key, addr in self._by_orth.items():
+            lex = Lexeme(self, key)
+            yield lex
 
     def __getitem__(self, id_or_string):
         """Retrieve a lexeme, given an int ID or a unicode string. If a
@@ -211,7 +212,7 @@ cdef class Vocab:
         >>> assert nlp.vocab[apple] == nlp.vocab[u'apple']
         """
         cdef attr_t orth
-        if type(id_or_string) == unicode:
+        if isinstance(id_or_string, unicode):
             orth = self.strings.add(id_or_string)
         else:
             orth = id_or_string
@@ -242,9 +243,69 @@ cdef class Vocab:
         """Drop the current vector table. Because all vectors must be the same
         width, you have to call this to change the size of the vectors.
         """
-        if new_dim is None:
-            new_dim = self.vectors.data.shape[1]
-        self.vectors = Vectors(self.strings, width=new_dim)
+        if width is None:
+            width = self.vectors.data.shape[1]
+        self.vectors = Vectors(self.strings, width=width)
+
+    def prune_vectors(self, nr_row, batch_size=8):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words:
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector
+        -- so "feline" would have the same vector as "cat", and "reclined"
+        would have the same vector as "sat".
+
+        The similarities are judged by cosine. The original vectors may
+        be large, so the cosines are calculated in minibatches, to reduce
+        memory usage.
+
+        nr_row (int): The number of rows to keep in the vector table.
+        batch_size (int): Batch of vectors for calculating the similarities.
+            Larger batch sizes might be faster, while temporarily requiring
+            more memory.
+        RETURNS (dict): A dictionary keyed by removed words mapped to
+            `(string, score)` tuples, where `string` is the entry the removed
+            word was mapped to, and `score` the similarity score between the
+            two words.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize the vectors, so cosine similarity is just dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        scores = xp.zeros((toss.shape[0],), dtype='f')
+        for i in range(0, toss.shape[0]//2, batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
+            sims = xp.dot(batch, keep)
+            matches = sims.argmax(axis=1)
+            neighbours[i:i+batch_size] = matches
+            scores[i:i+batch_size] = sims.max(axis=1)
+        i2k = {i: key for key, i in self.vectors.key2row.items()}
+        remap = {}
+        for lex in list(self):
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                lex.rank = neighbours[lex.rank-nr_row]
+                self.vectors.add(lex.orth, row=lex.rank)
+                remap[lex.orth_] = (i2k[lex.rank], scores[lex.rank])
+        for key, row in self.vectors.key2row.items():
+            if row >= nr_row:
+                self.vectors.key2row[key] = neighbours[row-nr_row]
+        # Make copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
+        link_vectors_to_models(self)
+        return remap
 
     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
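A minimal numpy-only sketch of the batched nearest-neighbour step above (standalone illustration with the same 1e-8 epsilon; note the WIP loop bound `toss.shape[0]//2` would only process half the discarded rows, so the sketch iterates the full range, which is presumably the intent):

    import numpy as xp

    nr_row, batch_size = 2, 8
    data = xp.random.rand(6, 3).astype('f')

    keep = data[:nr_row]
    toss = data[nr_row:]
    # Normalize so cosine similarity reduces to a dot product.
    keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
    keep = xp.ascontiguousarray(keep.T)

    neighbours = xp.zeros((toss.shape[0],), dtype='i')
    scores = xp.zeros((toss.shape[0],), dtype='f')
    for i in range(0, toss.shape[0], batch_size):   # full range, not //2
        batch = toss[i : i+batch_size].copy()       # copy: don't mutate data
        batch /= xp.linalg.norm(batch, axis=1, keepdims=True) + 1e-8
        sims = xp.dot(batch, keep)                  # (batch, nr_row) cosines
        neighbours[i:i+batch_size] = sims.argmax(axis=1)
        scores[i:i+batch_size] = sims.max(axis=1)

    print(neighbours, scores)  # nearest kept row and similarity per tossed row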
@@ -266,9 +327,11 @@ cdef class Vocab:
         """Set a vector for a word in the vocabulary. Words can be referenced
         by string or int ID.
         """
-        if not isinstance(orth, basestring_):
-            orth = self.strings[orth]
+        if self.vectors.data.size == 0:
+            self.clear_vectors(vector.shape[0])
+        lex = self[orth]
         self.vectors.add(orth, vector=vector)
+        lex.rank = self.vectors.key2row[lex.orth]
 
     def has_vector(self, orth):
         """Check whether a word has a vector. Returns False if no vectors have
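With this change, `set_vector` also synchronizes the lexeme's `rank` with the vector row, so as of this WIP commit the following should hold (usage sketch against the commit's API):

    import numpy
    from spacy.vocab import Vocab

    vocab = Vocab()
    data = numpy.ndarray((5, 3), dtype='f')
    data[0] = 1.
    vocab.set_vector(u'cat', data[0])

    lex = vocab[u'cat']
    # set_vector now records the row, so rank and key2row agree
    assert lex.rank == vocab.vectors.key2row[lex.orth]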