Fix vector for 0-length span (#9244)

This commit is contained in:
Adriane Boyd 2021-09-20 20:22:49 +02:00 committed by GitHub
parent 015d439eb6
commit 00bdb31150
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 1 deletions

View File

@ -5,7 +5,9 @@ from spacy.attrs import ORTH, LENGTH
from spacy.tokens import Doc, Span, Token from spacy.tokens import Doc, Span, Token
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.util import filter_spans from spacy.util import filter_spans
from thinc.api import get_current_ops
from ..util import add_vecs_to_vocab
from .test_underscore import clean_underscore # noqa: F401 from .test_underscore import clean_underscore # noqa: F401
@ -412,3 +414,23 @@ def test_sent(en_tokenizer):
assert not span.doc.has_annotation("SENT_START") assert not span.doc.has_annotation("SENT_START")
with pytest.raises(ValueError): with pytest.raises(ValueError):
span.sent span.sent
def test_span_with_vectors(doc):
ops = get_current_ops()
prev_vectors = doc.vocab.vectors
vectors = [
("apple", ops.asarray([1, 2, 3])),
("orange", ops.asarray([-1, -2, -3])),
("And", ops.asarray([-1, -1, -1])),
("juice", ops.asarray([5, 5, 10])),
("pie", ops.asarray([7, 6.3, 8.9])),
]
add_vecs_to_vocab(doc.vocab, vectors)
# 0-length span
assert_array_equal(ops.to_numpy(doc[0:0].vector), numpy.zeros((3, )))
# longer span with no vector
assert_array_equal(ops.to_numpy(doc[0:4].vector), numpy.zeros((3, )))
# single-token span with vector
assert_array_equal(ops.to_numpy(doc[10:11].vector), [-1, -1, -1])
doc.vocab.vectors = prev_vectors

View File

@ -474,7 +474,11 @@ cdef class Span:
if "vector" in self.doc.user_span_hooks: if "vector" in self.doc.user_span_hooks:
return self.doc.user_span_hooks["vector"](self) return self.doc.user_span_hooks["vector"](self)
if self._vector is None: if self._vector is None:
self._vector = sum(t.vector for t in self) / len(self) if not len(self):
xp = get_array_module(self.vocab.vectors.data)
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
else:
self._vector = sum(t.vector for t in self) / len(self)
return self._vector return self._vector
@property @property