From e1664217f547bd4c5fb4fce70094832032dd172c Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 26 Jun 2023 10:25:20 +0200 Subject: [PATCH] Add spancat_singlelabel to debug data CLI (#12749) --- spacy/cli/debug_data.py | 6 +++--- spacy/tests/test_cli.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index e3d0a102f..af3c24f3b 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -230,7 +230,7 @@ def debug_data( else: msg.info("No word vectors present in the package") - if "spancat" in factory_names: + if "spancat" in factory_names or "spancat_singlelabel" in factory_names: model_labels_spancat = _get_labels_from_spancat(nlp) has_low_data_warning = False has_no_neg_warning = False @@ -848,7 +848,7 @@ def _compile_gold( data["boundary_cross_ents"] += 1 elif label == "-": data["ner"]["-"] += 1 - if "spancat" in factory_names: + if "spancat" in factory_names or "spancat_singlelabel" in factory_names: for spans_key in list(eg.reference.spans.keys()): # Obtain the span frequency if spans_key not in data["spancat"]: @@ -1046,7 +1046,7 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]: pipe_names = [ pipe_name for pipe_name in nlp.pipe_names - if nlp.get_pipe_meta(pipe_name).factory == "spancat" + if nlp.get_pipe_meta(pipe_name).factory in ("spancat", "spancat_singlelabel") ] labels: Dict[str, Set[str]] = {} for pipe_name in pipe_names: diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index b1b1b8844..9a2d7705f 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -860,7 +860,8 @@ def test_debug_data_compile_gold(): assert data["boundary_cross_ents"] == 1 -def test_debug_data_compile_gold_for_spans(): +@pytest.mark.parametrize("component_name", ["spancat", "spancat_singlelabel"]) +def test_debug_data_compile_gold_for_spans(component_name): nlp = English() spans_key = "sc" @@ -870,7 +871,7 @@ def test_debug_data_compile_gold_for_spans(): ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")] eg = Example(pred, ref) - data = _compile_gold([eg], ["spancat"], nlp, True) + data = _compile_gold([eg], [component_name], nlp, True) assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1}) assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}