diff --git a/spacy/cli/init_pipeline.py b/spacy/cli/init_pipeline.py index 7c262d84d..2a920cdda 100644 --- a/spacy/cli/init_pipeline.py +++ b/spacy/cli/init_pipeline.py @@ -108,6 +108,10 @@ def init_labels_cli( config = util.load_config(config_path, overrides=overrides) with show_validation_error(hint_fill=False): nlp = init_nlp(config, use_gpu=use_gpu) + _init_labels(nlp, output_path) + + +def _init_labels(nlp, output_path): for name, component in nlp.pipeline: if getattr(component, "label_data", None) is not None: output_file = output_path / f"{name}.json" diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index f5d3d8da9..4e6c47b2f 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -329,7 +329,7 @@ class SpanCategorizer(TrainablePipe): get_examples: Callable[[], Iterable[Example]], *, nlp: Language = None, - labels: Optional[Dict] = None, + labels: Optional[List[str]] = None, ) -> None: """Initialize the pipe for training, using a representative set of data examples. 
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index 11324aa63..03bef3528 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -19,6 +19,7 @@ import srsly
 import os
 
 from .util import make_tempdir
+from ..cli.init_pipeline import _init_labels
 
 
 def test_cli_info():
@@ -501,3 +502,33 @@ def test_validate_compatibility_table():
     current_compat = compat.get(spacy_version, {})
     assert len(current_compat) > 0
     assert "en_core_web_sm" in current_compat
+
+
+@pytest.mark.parametrize("component_name", ["ner", "textcat", "spancat", "tagger"])
+def test_init_labels(component_name):  # round-trip check: labels -> file -> new pipeline
+    nlp = Dutch()
+    component = nlp.add_pipe(component_name)
+    for label in ["T1", "T2", "T3", "T4"]:
+        component.add_label(label)
+    assert len(nlp.get_pipe(component_name).labels) == 4  # sanity check on the source pipe
+    # Everything below stays inside the with-block because it reads files in tmp_dir.
+    with make_tempdir() as tmp_dir:
+        _init_labels(nlp, tmp_dir)  # expected to write <tmp_dir>/<component_name>.json
+        # Build a config whose "initialize" block points the component at that file.
+        config = init_config(
+            lang="nl",
+            pipeline=[component_name],
+            optimize="efficiency",
+            gpu=False,
+        )
+        config["initialize"]["components"][component_name] = {
+            "labels": {
+                "@readers": "spacy.read_labels.v1",
+                "path": f"{tmp_dir}/{component_name}.json",
+            }
+        }
+        # The new pipe starts without labels; they should only appear after initialize().
+        nlp2 = load_model_from_config(config, auto_fill=True)
+        assert len(nlp2.get_pipe(component_name).labels) == 0
+        nlp2.initialize()
+        assert len(nlp2.get_pipe(component_name).labels) == 4