mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add test for Language.pipe as_tuples with custom error handlers (#9608)
* make nlp.pipe() return None docs when no exceptions are (re-)raised during error handling * Remove changes other than as_tuples test * Only check warning count for one process * Fix types * Format Co-authored-by: Xi Bai <xi.bai.ed@gmail.com>
This commit is contained in:
parent
79cea03983
commit
db0d8c56d0
|
@ -1537,8 +1537,7 @@ class Language:
|
|||
yield (doc, context)
|
||||
return
|
||||
|
||||
# At this point, we know that we're dealing with an iterable of plain texts
|
||||
texts = cast(Iterable[str], texts)
|
||||
texts = cast(Iterable[Union[str, Doc]], texts)
|
||||
|
||||
# Set argument defaults
|
||||
if n_process == -1:
|
||||
|
@ -1592,7 +1591,7 @@ class Language:
|
|||
|
||||
def _multiprocessing_pipe(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
texts: Iterable[Union[str, Doc]],
|
||||
pipes: Iterable[Callable[..., Iterator[Doc]]],
|
||||
n_process: int,
|
||||
batch_size: int,
|
||||
|
|
|
@ -255,6 +255,38 @@ def test_language_pipe_error_handler_custom(en_vocab, n_process):
|
|||
assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process):
|
||||
"""Test the error handling of nlp.pipe with input as tuples"""
|
||||
Language.component("my_evil_component", func=evil_component)
|
||||
ops = get_current_ops()
|
||||
if isinstance(ops, NumpyOps) or n_process < 2:
|
||||
nlp = English()
|
||||
nlp.add_pipe("my_evil_component")
|
||||
texts = [
|
||||
("TEXT 111", 111),
|
||||
("TEXT 222", 222),
|
||||
("TEXT 333", 333),
|
||||
("TEXT 342", 342),
|
||||
("TEXT 666", 666),
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
list(nlp.pipe(texts, as_tuples=True))
|
||||
nlp.set_error_handler(warn_error)
|
||||
logger = logging.getLogger("spacy")
|
||||
with mock.patch.object(logger, "warning") as mock_warning:
|
||||
tuples = list(nlp.pipe(texts, as_tuples=True, n_process=n_process))
|
||||
# HACK/TODO? the warnings in child processes don't seem to be
|
||||
# detected by the mock logger
|
||||
if n_process == 1:
|
||||
mock_warning.assert_called()
|
||||
assert mock_warning.call_count == 2
|
||||
assert len(tuples) + mock_warning.call_count == len(texts)
|
||||
assert (tuples[0][0].text, tuples[0][1]) == ("TEXT 111", 111)
|
||||
assert (tuples[1][0].text, tuples[1][1]) == ("TEXT 333", 333)
|
||||
assert (tuples[2][0].text, tuples[2][1]) == ("TEXT 666", 666)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
def test_language_pipe_error_handler_pipe(en_vocab, n_process):
|
||||
"""Test the error handling of a component's pipe method"""
|
||||
|
@ -515,19 +547,19 @@ def test_spacy_blank():
|
|||
@pytest.mark.parametrize(
|
||||
"lang,target",
|
||||
[
|
||||
('en', 'en'),
|
||||
('fra', 'fr'),
|
||||
('fre', 'fr'),
|
||||
('iw', 'he'),
|
||||
('mo', 'ro'),
|
||||
('mul', 'xx'),
|
||||
('no', 'nb'),
|
||||
('pt-BR', 'pt'),
|
||||
('xx', 'xx'),
|
||||
('zh-Hans', 'zh'),
|
||||
('zh-Hant', None),
|
||||
('zxx', None)
|
||||
]
|
||||
("en", "en"),
|
||||
("fra", "fr"),
|
||||
("fre", "fr"),
|
||||
("iw", "he"),
|
||||
("mo", "ro"),
|
||||
("mul", "xx"),
|
||||
("no", "nb"),
|
||||
("pt-BR", "pt"),
|
||||
("xx", "xx"),
|
||||
("zh-Hans", "zh"),
|
||||
("zh-Hant", None),
|
||||
("zxx", None),
|
||||
],
|
||||
)
|
||||
def test_language_matching(lang, target):
|
||||
"""
|
||||
|
@ -540,17 +572,17 @@ def test_language_matching(lang, target):
|
|||
@pytest.mark.parametrize(
|
||||
"lang,target",
|
||||
[
|
||||
('en', 'en'),
|
||||
('fra', 'fr'),
|
||||
('fre', 'fr'),
|
||||
('iw', 'he'),
|
||||
('mo', 'ro'),
|
||||
('mul', 'xx'),
|
||||
('no', 'nb'),
|
||||
('pt-BR', 'pt'),
|
||||
('xx', 'xx'),
|
||||
('zh-Hans', 'zh'),
|
||||
]
|
||||
("en", "en"),
|
||||
("fra", "fr"),
|
||||
("fre", "fr"),
|
||||
("iw", "he"),
|
||||
("mo", "ro"),
|
||||
("mul", "xx"),
|
||||
("no", "nb"),
|
||||
("pt-BR", "pt"),
|
||||
("xx", "xx"),
|
||||
("zh-Hans", "zh"),
|
||||
],
|
||||
)
|
||||
def test_blank_languages(lang, target):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user