Add spancat_singlelabel to debug data CLI (#12749)

This commit is contained in:
Adriane Boyd 2023-06-26 10:25:20 +02:00 committed by GitHub
parent cb4fdc83e4
commit e1664217f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -230,7 +230,7 @@ def debug_data(
else: else:
msg.info("No word vectors present in the package") msg.info("No word vectors present in the package")
if "spancat" in factory_names: if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
model_labels_spancat = _get_labels_from_spancat(nlp) model_labels_spancat = _get_labels_from_spancat(nlp)
has_low_data_warning = False has_low_data_warning = False
has_no_neg_warning = False has_no_neg_warning = False
@ -848,7 +848,7 @@ def _compile_gold(
data["boundary_cross_ents"] += 1 data["boundary_cross_ents"] += 1
elif label == "-": elif label == "-":
data["ner"]["-"] += 1 data["ner"]["-"] += 1
if "spancat" in factory_names: if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
for spans_key in list(eg.reference.spans.keys()): for spans_key in list(eg.reference.spans.keys()):
# Obtain the span frequency # Obtain the span frequency
if spans_key not in data["spancat"]: if spans_key not in data["spancat"]:
@ -1046,7 +1046,7 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
pipe_names = [ pipe_names = [
pipe_name pipe_name
for pipe_name in nlp.pipe_names for pipe_name in nlp.pipe_names
if nlp.get_pipe_meta(pipe_name).factory == "spancat" if nlp.get_pipe_meta(pipe_name).factory in ("spancat", "spancat_singlelabel")
] ]
labels: Dict[str, Set[str]] = {} labels: Dict[str, Set[str]] = {}
for pipe_name in pipe_names: for pipe_name in pipe_names:

View File

@ -860,7 +860,8 @@ def test_debug_data_compile_gold():
assert data["boundary_cross_ents"] == 1 assert data["boundary_cross_ents"] == 1
def test_debug_data_compile_gold_for_spans(): @pytest.mark.parametrize("component_name", ["spancat", "spancat_singlelabel"])
def test_debug_data_compile_gold_for_spans(component_name):
nlp = English() nlp = English()
spans_key = "sc" spans_key = "sc"
@ -870,7 +871,7 @@ def test_debug_data_compile_gold_for_spans():
ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")] ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
eg = Example(pred, ref) eg = Example(pred, ref)
data = _compile_gold([eg], ["spancat"], nlp, True) data = _compile_gold([eg], [component_name], nlp, True)
assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1}) assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1})
assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]} assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}