Add test for Language.pipe as_tuples with custom error handlers (#9608)

* Make nlp.pipe() return `None` docs when no exceptions are (re-)raised during error handling

* Remove changes other than as_tuples test

* Only check warning count for one process

* Fix types

* Format

Co-authored-by: Xi Bai <xi.bai.ed@gmail.com>
This commit is contained in:
Adriane Boyd 2021-11-03 10:57:34 +01:00 committed by GitHub
parent 79cea03983
commit db0d8c56d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 27 deletions

View File

@ -1537,8 +1537,7 @@ class Language:
yield (doc, context) yield (doc, context)
return return
# At this point, we know that we're dealing with an iterable of plain texts texts = cast(Iterable[Union[str, Doc]], texts)
texts = cast(Iterable[str], texts)
# Set argument defaults # Set argument defaults
if n_process == -1: if n_process == -1:
@ -1592,7 +1591,7 @@ class Language:
def _multiprocessing_pipe( def _multiprocessing_pipe(
self, self,
texts: Iterable[str], texts: Iterable[Union[str, Doc]],
pipes: Iterable[Callable[..., Iterator[Doc]]], pipes: Iterable[Callable[..., Iterator[Doc]]],
n_process: int, n_process: int,
batch_size: int, batch_size: int,

View File

@ -255,6 +255,38 @@ def test_language_pipe_error_handler_custom(en_vocab, n_process):
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"] assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
@pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process):
    """Check nlp.pipe error handling when the input is (text, context) tuples.

    With the default error handler the failing component is expected to raise
    a ValueError. After switching to ``warn_error``, the test expects the
    failing inputs to be dropped with a warning while every remaining doc is
    still paired with its own context value.
    """
    Language.component("my_evil_component", func=evil_component)
    # The multi-process variant only runs on NumpyOps; on any other backend
    # only the single-process case is exercised.
    on_numpy = isinstance(get_current_ops(), NumpyOps)
    if not on_numpy and n_process >= 2:
        return

    nlp = English()
    nlp.add_pipe("my_evil_component")
    inputs = [
        ("TEXT 111", 111),
        ("TEXT 222", 222),
        ("TEXT 333", 333),
        ("TEXT 342", 342),
        ("TEXT 666", 666),
    ]
    # (text, context) pairs expected to come back once errors only warn.
    surviving = [
        ("TEXT 111", 111),
        ("TEXT 333", 333),
        ("TEXT 666", 666),
    ]

    # Before a custom handler is installed, the error must propagate.
    with pytest.raises(ValueError):
        list(nlp.pipe(inputs, as_tuples=True))

    nlp.set_error_handler(warn_error)
    spacy_logger = logging.getLogger("spacy")
    with mock.patch.object(spacy_logger, "warning") as warning_mock:
        results = list(nlp.pipe(inputs, as_tuples=True, n_process=n_process))
        # HACK/TODO? Warnings raised in child processes don't seem to reach
        # the mocked logger, so the counts are only checked for one process.
        if n_process == 1:
            warning_mock.assert_called()
            assert warning_mock.call_count == 2
            assert len(results) + warning_mock.call_count == len(inputs)
        for position, expected in enumerate(surviving):
            produced = (results[position][0].text, results[position][1])
            assert produced == expected
@pytest.mark.parametrize("n_process", [1, 2]) @pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe_error_handler_pipe(en_vocab, n_process): def test_language_pipe_error_handler_pipe(en_vocab, n_process):
"""Test the error handling of a component's pipe method""" """Test the error handling of a component's pipe method"""
@ -515,19 +547,19 @@ def test_spacy_blank():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"lang,target", "lang,target",
[ [
('en', 'en'), ("en", "en"),
('fra', 'fr'), ("fra", "fr"),
('fre', 'fr'), ("fre", "fr"),
('iw', 'he'), ("iw", "he"),
('mo', 'ro'), ("mo", "ro"),
('mul', 'xx'), ("mul", "xx"),
('no', 'nb'), ("no", "nb"),
('pt-BR', 'pt'), ("pt-BR", "pt"),
('xx', 'xx'), ("xx", "xx"),
('zh-Hans', 'zh'), ("zh-Hans", "zh"),
('zh-Hant', None), ("zh-Hant", None),
('zxx', None) ("zxx", None),
] ],
) )
def test_language_matching(lang, target): def test_language_matching(lang, target):
""" """
@ -540,17 +572,17 @@ def test_language_matching(lang, target):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"lang,target", "lang,target",
[ [
('en', 'en'), ("en", "en"),
('fra', 'fr'), ("fra", "fr"),
('fre', 'fr'), ("fre", "fr"),
('iw', 'he'), ("iw", "he"),
('mo', 'ro'), ("mo", "ro"),
('mul', 'xx'), ("mul", "xx"),
('no', 'nb'), ("no", "nb"),
('pt-BR', 'pt'), ("pt-BR", "pt"),
('xx', 'xx'), ("xx", "xx"),
('zh-Hans', 'zh'), ("zh-Hans", "zh"),
] ],
) )
def test_blank_languages(lang, target): def test_blank_languages(lang, target):
""" """