Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-24 16:24:16 +03:00)
Add spancat_singlelabel to debug data CLI (#12749)
This commit is contained in:
Parent: cb4fdc83e4
Commit: e1664217f5

@@ -230,7 +230,7 @@ def debug_data(
 else:
 msg.info("No word vectors present in the package")
 
-if "spancat" in factory_names:
+if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
 model_labels_spancat = _get_labels_from_spancat(nlp)
 has_low_data_warning = False
 has_no_neg_warning = False
@@ -848,7 +848,7 @@ def _compile_gold(
 data["boundary_cross_ents"] += 1
 elif label == "-":
 data["ner"]["-"] += 1
-if "spancat" in factory_names:
+if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
 for spans_key in list(eg.reference.spans.keys()):
 # Obtain the span frequency
 if spans_key not in data["spancat"]:
@@ -1046,7 +1046,7 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
 pipe_names = [
 pipe_name
 for pipe_name in nlp.pipe_names
-if nlp.get_pipe_meta(pipe_name).factory == "spancat"
+if nlp.get_pipe_meta(pipe_name).factory in ("spancat", "spancat_singlelabel")
 ]
 labels: Dict[str, Set[str]] = {}
 for pipe_name in pipe_names:

@@ -860,7 +860,8 @@ def test_debug_data_compile_gold():
 assert data["boundary_cross_ents"] == 1
 
 
-def test_debug_data_compile_gold_for_spans():
+@pytest.mark.parametrize("component_name", ["spancat", "spancat_singlelabel"])
+def test_debug_data_compile_gold_for_spans(component_name):
 nlp = English()
 spans_key = "sc"
 
@@ -870,7 +871,7 @@ def test_debug_data_compile_gold_for_spans():
 ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
 eg = Example(pred, ref)
 
-data = _compile_gold([eg], ["spancat"], nlp, True)
+data = _compile_gold([eg], [component_name], nlp, True)
 
 assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1})
 assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}
Loading…
Reference in New Issue
Block a user