mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix vector for 0-length span (#9244)
This commit is contained in:
parent
015d439eb6
commit
00bdb31150
|
@ -5,7 +5,9 @@ from spacy.attrs import ORTH, LENGTH
|
|||
from spacy.tokens import Doc, Span, Token
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.util import filter_spans
|
||||
from thinc.api import get_current_ops
|
||||
|
||||
from ..util import add_vecs_to_vocab
|
||||
from .test_underscore import clean_underscore # noqa: F401
|
||||
|
||||
|
||||
|
@ -412,3 +414,23 @@ def test_sent(en_tokenizer):
|
|||
assert not span.doc.has_annotation("SENT_START")
|
||||
with pytest.raises(ValueError):
|
||||
span.sent
|
||||
|
||||
|
||||
def test_span_with_vectors(doc):
    """Check ``Span.vector`` for empty, vector-less and single-token spans.

    Custom 3-dimensional vectors are installed on ``doc.vocab`` through
    ``add_vecs_to_vocab`` for the duration of the test. The vectors that were
    on the vocab beforehand are put back in a ``finally`` clause, so a failing
    assertion does not leave the temporary table attached to the vocab.
    """
    ops = get_current_ops()
    # Keep a handle on the current vectors so they can be restored at the end.
    prev_vectors = doc.vocab.vectors
    vectors = [
        ("apple", ops.asarray([1, 2, 3])),
        ("orange", ops.asarray([-1, -2, -3])),
        ("And", ops.asarray([-1, -1, -1])),
        ("juice", ops.asarray([5, 5, 10])),
        ("pie", ops.asarray([7, 6.3, 8.9])),
    ]
    try:
        add_vecs_to_vocab(doc.vocab, vectors)
        # 0-length span
        assert_array_equal(ops.to_numpy(doc[0:0].vector), numpy.zeros((3,)))
        # longer span with no vector
        assert_array_equal(ops.to_numpy(doc[0:4].vector), numpy.zeros((3,)))
        # single-token span with vector; the expected value is the "And"
        # vector, so this presumably relies on token 10 of the `doc` fixture
        # being "And" -- TODO confirm against the fixture.
        assert_array_equal(ops.to_numpy(doc[10:11].vector), [-1, -1, -1])
    finally:
        # Restore unconditionally: the original code skipped this line
        # whenever one of the assertions above raised.
        doc.vocab.vectors = prev_vectors
|
||||
|
|
|
@ -474,6 +474,10 @@ cdef class Span:
|
|||
if "vector" in self.doc.user_span_hooks:
|
||||
return self.doc.user_span_hooks["vector"](self)
|
||||
if self._vector is None:
|
||||
if not len(self):
|
||||
xp = get_array_module(self.vocab.vectors.data)
|
||||
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
|
||||
else:
|
||||
self._vector = sum(t.vector for t in self) / len(self)
|
||||
return self._vector
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user