Check that row is within bounds when adding vector (#5430)

Check that row is within bounds for the vector data array when adding a
vector.

Don't add vectors with rank OOV_RANK in `init-model` (change is due to
shift from OOV as 0 to OOV as OOV_RANK).
This commit is contained in:
adrianeboyd 2020-05-13 22:08:28 +02:00 committed by GitHub
parent 07639dd6ac
commit 113e7981d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 12 additions and 3 deletions

View File

@@ -181,7 +181,7 @@ def add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, name=None):
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"): if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb"))) nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
for lex in nlp.vocab: for lex in nlp.vocab:
if lex.rank: if lex.rank and lex.rank != OOV_RANK:
nlp.vocab.vectors.add(lex.orth, row=lex.rank) nlp.vocab.vectors.add(lex.orth, row=lex.rank)
else: else:
if vectors_loc: if vectors_loc:

View File

@@ -1,6 +1,7 @@
# coding: utf8 # coding: utf8
from __future__ import unicode_literals from __future__ import unicode_literals
def add_codes(err_cls): def add_codes(err_cls):
"""Add error codes to string messages via class attribute names.""" """Add error codes to string messages via class attribute names."""
@@ -555,6 +556,7 @@ class Errors(object):
E195 = ("Matcher can be called on {good} only, got {got}.") E195 = ("Matcher can be called on {good} only, got {got}.")
E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can " E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can "
"only be fixed with token.is_sent_start.") "only be fixed with token.is_sent_start.")
E197 = ("Row out of bounds, unable to add row {row} for key {key}.")
@add_codes @add_codes

View File

@@ -307,6 +307,9 @@ def test_vocab_add_vector():
dog = vocab["dog"] dog = vocab["dog"]
assert list(dog.vector) == [2.0, 2.0, 2.0] assert list(dog.vector) == [2.0, 2.0, 2.0]
with pytest.raises(ValueError):
vocab.vectors.add(vocab["hamster"].orth, row=1000000)
def test_vocab_prune_vectors(): def test_vocab_prune_vectors():
vocab = Vocab(vectors_name="test_vocab_prune_vectors") vocab = Vocab(vectors_name="test_vocab_prune_vectors")

View File

@@ -9,6 +9,7 @@ import functools
import numpy import numpy
from collections import OrderedDict from collections import OrderedDict
import srsly import srsly
import warnings
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module
from thinc.neural._classes.model import Model from thinc.neural._classes.model import Model
@@ -303,7 +304,10 @@ cdef class Vectors:
raise ValueError(Errors.E060.format(rows=self.data.shape[0], raise ValueError(Errors.E060.format(rows=self.data.shape[0],
cols=self.data.shape[1])) cols=self.data.shape[1]))
row = deref(self._unset.begin()) row = deref(self._unset.begin())
self.key2row[key] = row if row < self.data.shape[0]:
self.key2row[key] = row
else:
raise ValueError(Errors.E197.format(row=row, key=key))
if vector is not None: if vector is not None:
self.data[row] = vector self.data[row] = vector
if self._unset.count(row): if self._unset.count(row):

View File

@@ -319,7 +319,7 @@ cdef class Vocab:
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64") keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]]) keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]]) toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
self.vectors = Vectors(data=keep, keys=keys, name=self.vectors.name) self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size) syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
remap = {} remap = {}
for i, key in enumerate(keys[nr_row:]): for i, key in enumerate(keys[nr_row:]):