diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index ea09d990c..d6d04f158 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -1,6 +1,6 @@ import pytest from spacy.language import Language -from spacy.util import SimpleFrozenList +from spacy.util import SimpleFrozenList, get_arg_names @pytest.fixture @@ -346,3 +346,34 @@ def test_pipe_methods_frozen(): nlp.components.sort() with pytest.raises(NotImplementedError): nlp.component_names.clear() + + +@pytest.mark.parametrize( + "pipe", + [ + "tagger", + "parser", + "ner", + "textcat", + pytest.param("morphologizer", marks=pytest.mark.xfail), + ], +) +def test_pipe_label_data_exports_labels(pipe): + nlp = Language() + pipe = nlp.add_pipe(pipe) + # Make sure pipe has pipe labels + assert getattr(pipe, "label_data", None) is not None + # Make sure pipe can be initialized with labels + initialize = getattr(pipe, "initialize", None) + assert initialize is not None + assert "labels" in get_arg_names(initialize) + + +@pytest.mark.parametrize("pipe", ["senter", "entity_linker"]) +def test_pipe_label_data_no_labels(pipe): + nlp = Language() + pipe = nlp.add_pipe(pipe) + assert getattr(pipe, "label_data", None) is None + initialize = getattr(pipe, "initialize", None) + if initialize is not None: + assert "labels" not in get_arg_names(initialize)