From fdfdbcd9f40c73eefe106f9ebf26767809d69a83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 12 Feb 2024 14:39:38 +0100 Subject: [PATCH] Make `Language.pipe` workers exit cleanly (#13321) Also warn when any worker exited with a non-zero exit code and modify test to ensure that workers exit cleanly by default. --- spacy/errors.py | 1 + spacy/language.py | 5 +++++ spacy/tests/test_language.py | 11 ++++++++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index b6108dd0f..cf9a7b708 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -220,6 +220,7 @@ class Warnings(metaclass=ErrorsWithCodes): "key attribute for vectors, configure it through Vectors(attr=) or " "'spacy init vectors --attr'") W126 = ("These keys are unsupported: {unsupported}") + W127 = ("Not all `Language.pipe` worker processes completed successfully") class Errors(metaclass=ErrorsWithCodes): diff --git a/spacy/language.py b/spacy/language.py index 568d2d4fa..18d20c939 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1730,6 +1730,9 @@ class Language: for proc in procs: proc.join() + if not all(proc.exitcode == 0 for proc in procs): + warnings.warn(Warnings.W127) + def _link_components(self) -> None: """Register 'listeners' within pipeline components, to allow them to effectively share weights. @@ -2350,6 +2353,7 @@ def _apply_pipes( if isinstance(texts_with_ctx, _WorkDoneSentinel): sender.close() receiver.close() + return docs = ( ensure_doc(doc_like, context) for doc_like, context in texts_with_ctx @@ -2375,6 +2379,7 @@ def _apply_pipes( # stop processing. sender.close() receiver.close() + return class _Sender: diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 51eec3239..d229739e1 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -1,5 +1,6 @@ import itertools import logging +import warnings from unittest import mock import pytest @@ -738,9 +739,13 @@ def test_pass_doc_to_pipeline(nlp, n_process): assert doc.text == texts[0] assert len(doc.cats) > 0 if isinstance(get_current_ops(), NumpyOps) or n_process < 2: - docs = nlp.pipe(docs, n_process=n_process) - assert [doc.text for doc in docs] == texts - assert all(len(doc.cats) for doc in docs) + # Catch warnings to ensure that all worker processes exited + # succesfully. + with warnings.catch_warnings(): + warnings.simplefilter("error") + docs = nlp.pipe(docs, n_process=n_process) + assert [doc.text for doc in docs] == texts + assert all(len(doc.cats) for doc in docs) def test_invalid_arg_to_pipeline(nlp):