Merge pull request #10346 from adrianeboyd/chore/v3.0-backport-10324

Fix Tok2Vec for empty batches (#10324)
This commit is contained in:
Adriane Boyd 2022-02-21 16:41:13 +01:00 committed by GitHub
commit f71de10405
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 19 deletions

View File

@@ -22,7 +22,7 @@ jobs:
# defined in .flake8 and overwrites the selected codes.
- job: "Validate"
pool:
vmImage: "ubuntu-18.04"
vmImage: "ubuntu-latest"
steps:
- task: UsePythonVersion@0
inputs:
@@ -38,41 +38,50 @@ jobs:
matrix:
# We're only running one platform per Python version to speed up builds
Python36Linux:
imageName: "ubuntu-18.04"
imageName: "ubuntu-latest"
python.version: "3.6"
# Python36Windows:
# imageName: "vs2017-win2016"
# imageName: "windows-latest"
# python.version: "3.6"
# Python36Mac:
# imageName: "macos-10.14"
# imageName: "macos-latest"
# python.version: "3.6"
# Python37Linux:
# imageName: "ubuntu-18.04"
# imageName: "ubuntu-latest"
# python.version: "3.7"
Python37Windows:
imageName: "vs2017-win2016"
imageName: "windows-latest"
python.version: "3.7"
# Python37Mac:
# imageName: "macos-10.14"
# imageName: "macos-latest"
# python.version: "3.7"
# Python38Linux:
# imageName: "ubuntu-18.04"
# imageName: "ubuntu-latest"
# python.version: "3.8"
# Python38Windows:
# imageName: "vs2017-win2016"
# imageName: "windows-latest"
# python.version: "3.8"
Python38Mac:
imageName: "macos-10.14"
imageName: "macos-latest"
python.version: "3.8"
Python39Linux:
imageName: "ubuntu-18.04"
python.version: "3.9"
Python39Windows:
imageName: "vs2017-win2016"
python.version: "3.9"
Python39Mac:
imageName: "macos-10.14"
imageName: "ubuntu-latest"
python.version: "3.9"
# Python39Windows:
# imageName: "windows-latest"
# python.version: "3.9"
# Python39Mac:
# imageName: "macos-latest"
# python.version: "3.9"
Python310Linux:
imageName: "ubuntu-latest"
python.version: "3.10"
Python310Windows:
imageName: "windows-latest"
python.version: "3.10"
Python310Mac:
imageName: "macos-latest"
python.version: "3.10"
maxParallel: 4
pool:
vmImage: $(imageName)

View File

@@ -28,3 +28,4 @@ pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0
flake8>=3.5.0,<3.6.0
hypothesis>=3.27.0,<7.0.0
mypy==0.910

View File

@@ -118,6 +118,10 @@ class Tok2Vec(TrainablePipe):
DOCS: https://spacy.io/api/tok2vec#predict
"""
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
width = self.model.get_dim("nO")
return [self.model.ops.alloc((0, width)) for doc in docs]
tokvecs = self.model.predict(docs)
batch_id = Tok2VecListener.get_batch_id(docs)
for listener in self.listeners:

View File

@@ -11,7 +11,7 @@ from spacy.lang.en import English
from thinc.api import Config, get_current_ops
from numpy.testing import assert_array_equal
from ..util import get_batch, make_tempdir
from ..util import get_batch, make_tempdir, add_vecs_to_vocab
def test_empty_doc():
@@ -134,9 +134,25 @@ TRAIN_DATA = [
]
def test_tok2vec_listener():
@pytest.mark.parametrize("with_vectors", (False, True))
def test_tok2vec_listener(with_vectors):
orig_config = Config().from_str(cfg_string)
orig_config["components"]["tok2vec"]["model"]["embed"][
"include_static_vectors"
] = with_vectors
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
if with_vectors:
ops = get_current_ops()
vectors = [
("apple", ops.asarray([1, 2, 3])),
("orange", ops.asarray([-1, -2, -3])),
("and", ops.asarray([-1, -1, -1])),
("juice", ops.asarray([5, 5, 10])),
("pie", ops.asarray([7, 6.3, 8.9])),
]
add_vecs_to_vocab(nlp.vocab, vectors)
assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec")
@@ -163,6 +179,9 @@ def test_tok2vec_listener():
ops = get_current_ops()
assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))
# test with empty doc
doc = nlp("")
# TODO: should this warn or error?
nlp.select_pipes(disable="tok2vec")
assert nlp.pipe_names == ["tagger"]