Add test for Language.pipe as_tuples with custom error handlers (#9608)

* Make nlp.pipe() return None docs when no exceptions are (re-)raised during error handling

* Remove changes other than as_tuples test

* Only check warning count for one process

* Fix types

* Format

Co-authored-by: Xi Bai <xi.bai.ed@gmail.com>
This commit is contained in:
Adriane Boyd 2021-11-03 10:57:34 +01:00 committed by GitHub
parent 79cea03983
commit db0d8c56d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 27 deletions

View File

@ -1537,8 +1537,7 @@ class Language:
yield (doc, context)
return
# At this point, we know that we're dealing with an iterable of plain texts
texts = cast(Iterable[str], texts)
texts = cast(Iterable[Union[str, Doc]], texts)
# Set argument defaults
if n_process == -1:
@ -1592,7 +1591,7 @@ class Language:
def _multiprocessing_pipe(
self,
texts: Iterable[str],
texts: Iterable[Union[str, Doc]],
pipes: Iterable[Callable[..., Iterator[Doc]]],
n_process: int,
batch_size: int,

View File

@ -255,6 +255,38 @@ def test_language_pipe_error_handler_custom(en_vocab, n_process):
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
@pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process):
    """Check nlp.pipe error handling when the input is (text, context) tuples.

    With the pipeline's initial error handler, piping the inputs must raise a
    ValueError. After installing ``warn_error``, the single-process run is
    expected to log two warnings, and the three remaining results must keep
    their original text and context value.
    """
    Language.component("my_evil_component", func=evil_component)
    current_ops = get_current_ops()
    # The multi-process variant is only exercised when the ops are NumpyOps.
    if not (isinstance(current_ops, NumpyOps) or n_process < 2):
        return

    nlp = English()
    nlp.add_pipe("my_evil_component")
    inputs = [
        ("TEXT 111", 111),
        ("TEXT 222", 222),
        ("TEXT 333", 333),
        ("TEXT 342", 342),
        ("TEXT 666", 666),
    ]

    # Before a custom handler is set, the component's error propagates.
    with pytest.raises(ValueError):
        list(nlp.pipe(inputs, as_tuples=True))

    nlp.set_error_handler(warn_error)
    spacy_logger = logging.getLogger("spacy")
    with mock.patch.object(spacy_logger, "warning") as mock_warning:
        results = list(nlp.pipe(inputs, as_tuples=True, n_process=n_process))
        # HACK/TODO? the warnings in child processes don't seem to be
        # detected by the mock logger
        if n_process == 1:
            mock_warning.assert_called()
            assert mock_warning.call_count == 2
            # Every input either produced a result or a warning.
            assert len(results) + mock_warning.call_count == len(inputs)
        survivors = [("TEXT 111", 111), ("TEXT 333", 333), ("TEXT 666", 666)]
        for position, expected_pair in enumerate(survivors):
            observed = (results[position][0].text, results[position][1])
            assert observed == expected_pair
@pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe_error_handler_pipe(en_vocab, n_process):
"""Test the error handling of a component's pipe method"""
@ -515,19 +547,19 @@ def test_spacy_blank():
@pytest.mark.parametrize(
"lang,target",
[
('en', 'en'),
('fra', 'fr'),
('fre', 'fr'),
('iw', 'he'),
('mo', 'ro'),
('mul', 'xx'),
('no', 'nb'),
('pt-BR', 'pt'),
('xx', 'xx'),
('zh-Hans', 'zh'),
('zh-Hant', None),
('zxx', None)
]
("en", "en"),
("fra", "fr"),
("fre", "fr"),
("iw", "he"),
("mo", "ro"),
("mul", "xx"),
("no", "nb"),
("pt-BR", "pt"),
("xx", "xx"),
("zh-Hans", "zh"),
("zh-Hant", None),
("zxx", None),
],
)
def test_language_matching(lang, target):
"""
@ -540,17 +572,17 @@ def test_language_matching(lang, target):
@pytest.mark.parametrize(
"lang,target",
[
('en', 'en'),
('fra', 'fr'),
('fre', 'fr'),
('iw', 'he'),
('mo', 'ro'),
('mul', 'xx'),
('no', 'nb'),
('pt-BR', 'pt'),
('xx', 'xx'),
('zh-Hans', 'zh'),
]
("en", "en"),
("fra", "fr"),
("fre", "fr"),
("iw", "he"),
("mo", "ro"),
("mul", "xx"),
("no", "nb"),
("pt-BR", "pt"),
("xx", "xx"),
("zh-Hans", "zh"),
],
)
def test_blank_languages(lang, target):
"""