mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Fix Tok2Vec for empty batches (#10324)
* Add test for tok2vec with vectors and empty docs * Add shortcut for empty batch in Tok2Vec.predict * Avoid types
This commit is contained in:
parent
6de84c8757
commit
f4c74764b8
|
@ -118,6 +118,10 @@ class Tok2Vec(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tok2vec#predict
|
DOCS: https://spacy.io/api/tok2vec#predict
|
||||||
"""
|
"""
|
||||||
|
if not any(len(doc) for doc in docs):
|
||||||
|
# Handle cases where there are no tokens in any docs.
|
||||||
|
width = self.model.get_dim("nO")
|
||||||
|
return [self.model.ops.alloc((0, width)) for doc in docs]
|
||||||
tokvecs = self.model.predict(docs)
|
tokvecs = self.model.predict(docs)
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||||
for listener in self.listeners:
|
for listener in self.listeners:
|
||||||
|
|
|
@ -11,7 +11,7 @@ from spacy.lang.en import English
|
||||||
from thinc.api import Config, get_current_ops
|
from thinc.api import Config, get_current_ops
|
||||||
from numpy.testing import assert_array_equal
|
from numpy.testing import assert_array_equal
|
||||||
|
|
||||||
from ..util import get_batch, make_tempdir
|
from ..util import get_batch, make_tempdir, add_vecs_to_vocab
|
||||||
|
|
||||||
|
|
||||||
def test_empty_doc():
|
def test_empty_doc():
|
||||||
|
@ -140,9 +140,25 @@ TRAIN_DATA = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_tok2vec_listener():
|
@pytest.mark.parametrize("with_vectors", (False, True))
|
||||||
|
def test_tok2vec_listener(with_vectors):
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
|
orig_config["components"]["tok2vec"]["model"]["embed"][
|
||||||
|
"include_static_vectors"
|
||||||
|
] = with_vectors
|
||||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
|
|
||||||
|
if with_vectors:
|
||||||
|
ops = get_current_ops()
|
||||||
|
vectors = [
|
||||||
|
("apple", ops.asarray([1, 2, 3])),
|
||||||
|
("orange", ops.asarray([-1, -2, -3])),
|
||||||
|
("and", ops.asarray([-1, -1, -1])),
|
||||||
|
("juice", ops.asarray([5, 5, 10])),
|
||||||
|
("pie", ops.asarray([7, 6.3, 8.9])),
|
||||||
|
]
|
||||||
|
add_vecs_to_vocab(nlp.vocab, vectors)
|
||||||
|
|
||||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
tagger = nlp.get_pipe("tagger")
|
tagger = nlp.get_pipe("tagger")
|
||||||
tok2vec = nlp.get_pipe("tok2vec")
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
@ -169,6 +185,9 @@ def test_tok2vec_listener():
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))
|
assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))
|
||||||
|
|
||||||
|
# test with empty doc
|
||||||
|
doc = nlp("")
|
||||||
|
|
||||||
# TODO: should this warn or error?
|
# TODO: should this warn or error?
|
||||||
nlp.select_pipes(disable="tok2vec")
|
nlp.select_pipes(disable="tok2vec")
|
||||||
assert nlp.pipe_names == ["tagger"]
|
assert nlp.pipe_names == ["tagger"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user