mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 00:50:33 +03:00
Merge pull request #10345 from adrianeboyd/chore/v3.1-backport-10324
Fix Tok2Vec for empty batches (#10324)
This commit is contained in:
commit
c69a8756b6
|
@ -23,7 +23,7 @@ jobs:
|
|||
# defined in .flake8 and overwrites the selected codes.
|
||||
- job: "Validate"
|
||||
pool:
|
||||
vmImage: "ubuntu-18.04"
|
||||
vmImage: "ubuntu-latest"
|
||||
steps:
|
||||
- task: UsePythonVersion@0
|
||||
inputs:
|
||||
|
@ -39,49 +39,49 @@ jobs:
|
|||
matrix:
|
||||
# We're only running one platform per Python version to speed up builds
|
||||
Python36Linux:
|
||||
imageName: "ubuntu-18.04"
|
||||
imageName: "ubuntu-latest"
|
||||
python.version: "3.6"
|
||||
# Python36Windows:
|
||||
# imageName: "windows-2019"
|
||||
# imageName: "windows-latest"
|
||||
# python.version: "3.6"
|
||||
# Python36Mac:
|
||||
# imageName: "macos-10.14"
|
||||
# imageName: "macos-latest"
|
||||
# python.version: "3.6"
|
||||
# Python37Linux:
|
||||
# imageName: "ubuntu-18.04"
|
||||
# imageName: "ubuntu-latest"
|
||||
# python.version: "3.7"
|
||||
Python37Windows:
|
||||
imageName: "windows-2019"
|
||||
imageName: "windows-latest"
|
||||
python.version: "3.7"
|
||||
# Python37Mac:
|
||||
# imageName: "macos-10.14"
|
||||
# imageName: "macos-latest"
|
||||
# python.version: "3.7"
|
||||
# Python38Linux:
|
||||
# imageName: "ubuntu-18.04"
|
||||
# imageName: "ubuntu-latest"
|
||||
# python.version: "3.8"
|
||||
# Python38Windows:
|
||||
# imageName: "windows-2019"
|
||||
# imageName: "windows-latest"
|
||||
# python.version: "3.8"
|
||||
Python38Mac:
|
||||
imageName: "macos-10.14"
|
||||
imageName: "macos-latest"
|
||||
python.version: "3.8"
|
||||
Python39Linux:
|
||||
imageName: "ubuntu-18.04"
|
||||
imageName: "ubuntu-latest"
|
||||
python.version: "3.9"
|
||||
# Python39Windows:
|
||||
# imageName: "windows-2019"
|
||||
# imageName: "windows-latest"
|
||||
# python.version: "3.9"
|
||||
# Python39Mac:
|
||||
# imageName: "macos-10.14"
|
||||
# imageName: "macos-latest"
|
||||
# python.version: "3.9"
|
||||
Python310Linux:
|
||||
imageName: "ubuntu-20.04"
|
||||
imageName: "ubuntu-latest"
|
||||
python.version: "3.10"
|
||||
Python310Windows:
|
||||
imageName: "windows-2019"
|
||||
imageName: "windows-latest"
|
||||
python.version: "3.10"
|
||||
Python310Mac:
|
||||
imageName: "macos-10.15"
|
||||
imageName: "macos-latest"
|
||||
python.version: "3.10"
|
||||
maxParallel: 4
|
||||
pool:
|
||||
|
|
|
@ -29,7 +29,7 @@ pytest-timeout>=1.3.0,<2.0.0
|
|||
mock>=2.0.0,<3.0.0
|
||||
flake8>=3.8.0,<3.10.0
|
||||
hypothesis>=3.27.0,<7.0.0
|
||||
mypy>=0.910
|
||||
mypy==0.910
|
||||
types-dataclasses>=0.1.3; python_version < "3.7"
|
||||
types-mock>=0.1.1
|
||||
types-requests
|
||||
|
|
|
@ -19,7 +19,7 @@ class Lexeme:
|
|||
@property
|
||||
def vector_norm(self) -> float: ...
|
||||
vector: Floats1d
|
||||
rank: str
|
||||
rank: int
|
||||
sentiment: float
|
||||
@property
|
||||
def orth_(self) -> str: ...
|
||||
|
|
|
@ -118,6 +118,10 @@ class Tok2Vec(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tok2vec#predict
|
||||
"""
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
width = self.model.get_dim("nO")
|
||||
return [self.model.ops.alloc((0, width)) for doc in docs]
|
||||
tokvecs = self.model.predict(docs)
|
||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||
for listener in self.listeners:
|
||||
|
|
|
@ -11,7 +11,7 @@ from spacy.lang.en import English
|
|||
from thinc.api import Config, get_current_ops
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from ..util import get_batch, make_tempdir
|
||||
from ..util import get_batch, make_tempdir, add_vecs_to_vocab
|
||||
|
||||
|
||||
def test_empty_doc():
|
||||
|
@ -140,9 +140,25 @@ TRAIN_DATA = [
|
|||
]
|
||||
|
||||
|
||||
def test_tok2vec_listener():
|
||||
@pytest.mark.parametrize("with_vectors", (False, True))
|
||||
def test_tok2vec_listener(with_vectors):
|
||||
orig_config = Config().from_str(cfg_string)
|
||||
orig_config["components"]["tok2vec"]["model"]["embed"][
|
||||
"include_static_vectors"
|
||||
] = with_vectors
|
||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
|
||||
if with_vectors:
|
||||
ops = get_current_ops()
|
||||
vectors = [
|
||||
("apple", ops.asarray([1, 2, 3])),
|
||||
("orange", ops.asarray([-1, -2, -3])),
|
||||
("and", ops.asarray([-1, -1, -1])),
|
||||
("juice", ops.asarray([5, 5, 10])),
|
||||
("pie", ops.asarray([7, 6.3, 8.9])),
|
||||
]
|
||||
add_vecs_to_vocab(nlp.vocab, vectors)
|
||||
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
|
@ -169,6 +185,9 @@ def test_tok2vec_listener():
|
|||
ops = get_current_ops()
|
||||
assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))
|
||||
|
||||
# test with empty doc
|
||||
doc = nlp("")
|
||||
|
||||
# TODO: should this warn or error?
|
||||
nlp.select_pipes(disable="tok2vec")
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user