mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Fix vector for 0-length span (#9244)
This commit is contained in:
parent
015d439eb6
commit
00bdb31150
|
@ -5,7 +5,9 @@ from spacy.attrs import ORTH, LENGTH
|
||||||
from spacy.tokens import Doc, Span, Token
|
from spacy.tokens import Doc, Span, Token
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.util import filter_spans
|
from spacy.util import filter_spans
|
||||||
|
from thinc.api import get_current_ops
|
||||||
|
|
||||||
|
from ..util import add_vecs_to_vocab
|
||||||
from .test_underscore import clean_underscore # noqa: F401
|
from .test_underscore import clean_underscore # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
@ -412,3 +414,23 @@ def test_sent(en_tokenizer):
|
||||||
assert not span.doc.has_annotation("SENT_START")
|
assert not span.doc.has_annotation("SENT_START")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
span.sent
|
span.sent
|
||||||
|
|
||||||
|
|
||||||
|
def test_span_with_vectors(doc):
    """Check `Span.vector` for an empty span, a span with no vectors and a
    single-token span that has a vector.

    doc: the `doc` test fixture; its vocab is given 3-dimensional vectors for
        a handful of words for the duration of the test.
    """
    ops = get_current_ops()
    # Remember the vectors currently attached to the vocab so they can be put
    # back afterwards.
    # NOTE(review): the fixture's vocab is presumably shared with other tests
    # -- confirm against the fixture definition.
    prev_vectors = doc.vocab.vectors
    vectors = [
        ("apple", ops.asarray([1, 2, 3])),
        ("orange", ops.asarray([-1, -2, -3])),
        ("And", ops.asarray([-1, -1, -1])),
        ("juice", ops.asarray([5, 5, 10])),
        ("pie", ops.asarray([7, 6.3, 8.9])),
    ]
    try:
        add_vecs_to_vocab(doc.vocab, vectors)
        # 0-length span: must yield a zero vector of the vocab's vector width
        # rather than fail on the division by len(span).
        assert_array_equal(ops.to_numpy(doc[0:0].vector), numpy.zeros((3,)))
        # longer span with no vector
        assert_array_equal(ops.to_numpy(doc[0:4].vector), numpy.zeros((3,)))
        # single-token span with vector (the expected value matches the "And"
        # entry above; presumably token 10 of the fixture is "And")
        assert_array_equal(ops.to_numpy(doc[10:11].vector), [-1, -1, -1])
    finally:
        # Restore the original vectors even when an assertion fails, so a
        # failure here does not leave the test vectors on the vocab.
        doc.vocab.vectors = prev_vectors
|
||||||
|
|
|
@ -474,7 +474,11 @@ cdef class Span:
|
||||||
if "vector" in self.doc.user_span_hooks:
|
if "vector" in self.doc.user_span_hooks:
|
||||||
return self.doc.user_span_hooks["vector"](self)
|
return self.doc.user_span_hooks["vector"](self)
|
||||||
if self._vector is None:
|
if self._vector is None:
|
||||||
self._vector = sum(t.vector for t in self) / len(self)
|
if not len(self):
|
||||||
|
xp = get_array_module(self.vocab.vectors.data)
|
||||||
|
self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
|
||||||
|
else:
|
||||||
|
self._vector = sum(t.vector for t in self) / len(self)
|
||||||
return self._vector
|
return self._vector
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
Loading…
Reference in New Issue
Block a user