mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Check that row is within bounds when adding vector (#5430)
Check that row is within bounds for the vector data array when adding a vector. Don't add vectors with rank OOV_RANK in `init-model` (change is due to shift from OOV as 0 to OOV as OOV_RANK).
This commit is contained in:
parent
07639dd6ac
commit
113e7981d0
|
@ -181,7 +181,7 @@ def add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, name=None):
|
|||
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
|
||||
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
|
||||
for lex in nlp.vocab:
|
||||
if lex.rank:
|
||||
if lex.rank and lex.rank != OOV_RANK:
|
||||
nlp.vocab.vectors.add(lex.orth, row=lex.rank)
|
||||
else:
|
||||
if vectors_loc:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
def add_codes(err_cls):
|
||||
"""Add error codes to string messages via class attribute names."""
|
||||
|
||||
|
@ -555,6 +556,7 @@ class Errors(object):
|
|||
E195 = ("Matcher can be called on {good} only, got {got}.")
|
||||
E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can "
|
||||
"only be fixed with token.is_sent_start.")
|
||||
E197 = ("Row out of bounds, unable to add row {row} for key {key}.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -307,6 +307,9 @@ def test_vocab_add_vector():
|
|||
dog = vocab["dog"]
|
||||
assert list(dog.vector) == [2.0, 2.0, 2.0]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vocab.vectors.add(vocab["hamster"].orth, row=1000000)
|
||||
|
||||
|
||||
def test_vocab_prune_vectors():
|
||||
vocab = Vocab(vectors_name="test_vocab_prune_vectors")
|
||||
|
|
|
@ -9,6 +9,7 @@ import functools
|
|||
import numpy
|
||||
from collections import OrderedDict
|
||||
import srsly
|
||||
import warnings
|
||||
from thinc.neural.util import get_array_module
|
||||
from thinc.neural._classes.model import Model
|
||||
|
||||
|
@ -303,7 +304,10 @@ cdef class Vectors:
|
|||
raise ValueError(Errors.E060.format(rows=self.data.shape[0],
|
||||
cols=self.data.shape[1]))
|
||||
row = deref(self._unset.begin())
|
||||
self.key2row[key] = row
|
||||
if row < self.data.shape[0]:
|
||||
self.key2row[key] = row
|
||||
else:
|
||||
raise ValueError(Errors.E197.format(row=row, key=key))
|
||||
if vector is not None:
|
||||
self.data[row] = vector
|
||||
if self._unset.count(row):
|
||||
|
|
|
@ -319,7 +319,7 @@ cdef class Vocab:
|
|||
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
|
||||
keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
|
||||
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
|
||||
self.vectors = Vectors(data=keep, keys=keys, name=self.vectors.name)
|
||||
self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
|
||||
syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
|
||||
remap = {}
|
||||
for i, key in enumerate(keys[nr_row:]):
|
||||
|
|
Loading…
Reference in New Issue
Block a user