Add spancat pipeline in spacy debug data (#10070)

* Setup debug data for spancat

* Add check for missing labels

* Add low-level data warning error

* Improve logic when compiling the gold train data

* Implement check for negative examples

* Remove breakpoint

* Remove ws_ents and missing entity checks

* Fix mypy errors

* Make variable name spans_key consistent

* Rename pipeline -> component for consistency

* Account for missing labels per spans_key

* Cleanup variable names for consistency

* Improve brevity of conditional statements

* Remove unused variables

* Include spans_key as an argument for _get_examples

* Add a conditional check for spans_key

* Update spancat debug data based on new API

- Instead of using _get_labels_from_model(), I'm now using
_get_labels_from_spancat() (cf. https://github.com/explosion/spaCy/pull10079)
- The way information is displayed was also changed (text -> table)

* Rename model_labels to ensure mypy works

* Update wording on warning messages

Use "span type" instead of "entity type" in wording the warning messages.
This is because Spans aren't necessarily entities.

* Update component type into a Literal

This is to make it clear that the component parameter should only accept
either 'spancat' or 'ner'.

* Update checks to include actual model span_keys

Instead of looking at everything in the data, we only check those
span_keys from the actual spancat component. Instead of doing the filter
inside the for-loop, I just made another dictionary,
data_labels_in_component to hold this value.

* Update spacy/cli/debug_data.py

* Show label counts only when verbose is True

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Lj Miranda 2022-02-07 22:03:36 +08:00 committed by GitHub
parent 72fece712f
commit 42072f4468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -193,6 +193,70 @@ def debug_data(
else:
msg.info("No word vectors present in the package")
if "spancat" in factory_names:
model_labels_spancat = _get_labels_from_spancat(nlp)
has_low_data_warning = False
has_no_neg_warning = False
msg.divider("Span Categorization")
msg.table(model_labels_spancat, header=["Spans Key", "Labels"], divider=True)
msg.text("Label counts in train data: ", show=verbose)
for spans_key, data_labels in gold_train_data["spancat"].items():
msg.text(
f"Key: {spans_key}, {_format_labels(data_labels.items(), counts=True)}",
show=verbose,
)
# Data checks: only take the spans keys in the actual spancat components
data_labels_in_component = {
spans_key: gold_train_data["spancat"][spans_key]
for spans_key in model_labels_spancat.keys()
}
for spans_key, data_labels in data_labels_in_component.items():
for label, count in data_labels.items():
# Check for missing labels
spans_key_in_model = spans_key in model_labels_spancat.keys()
if (spans_key_in_model) and (
label not in model_labels_spancat[spans_key]
):
msg.warn(
f"Label '{label}' is not present in the model labels of key '{spans_key}'. "
"Performance may degrade after training."
)
# Check for low number of examples per label
if count <= NEW_LABEL_THRESHOLD:
msg.warn(
f"Low number of examples for label '{label}' in key '{spans_key}' ({count})"
)
has_low_data_warning = True
# Check for negative examples
with msg.loading("Analyzing label distribution..."):
neg_docs = _get_examples_without_label(
train_dataset, label, "spancat", spans_key
)
if neg_docs == 0:
msg.warn(f"No examples for texts WITHOUT new label '{label}'")
has_no_neg_warning = True
if has_low_data_warning:
msg.text(
f"To train a new span type, your data should include at "
f"least {NEW_LABEL_THRESHOLD} instances of the new label",
show=verbose,
)
else:
msg.good("Good amount of examples for all labels")
if has_no_neg_warning:
msg.text(
"Training data should always include examples of spans "
"in context, as well as examples without a given span "
"type.",
show=verbose,
)
else:
msg.good("Examples without ocurrences available for all labels")
if "ner" in factory_names:
# Get all unique NER labels present in the data
labels = set(
@ -238,7 +302,7 @@ def debug_data(
has_low_data_warning = True
with msg.loading("Analyzing label distribution..."):
neg_docs = _get_examples_without_label(train_dataset, label)
neg_docs = _get_examples_without_label(train_dataset, label, "ner")
if neg_docs == 0:
msg.warn(f"No examples for texts WITHOUT new label '{label}'")
has_no_neg_warning = True
@ -573,6 +637,7 @@ def _compile_gold(
"deps": Counter(),
"words": Counter(),
"roots": Counter(),
"spancat": dict(),
"ws_ents": 0,
"boundary_cross_ents": 0,
"n_words": 0,
@ -617,6 +682,15 @@ def _compile_gold(
data["boundary_cross_ents"] += 1
elif label == "-":
data["ner"]["-"] += 1
if "spancat" in factory_names:
for span_key in list(eg.reference.spans.keys()):
if span_key not in data["spancat"]:
data["spancat"][span_key] = Counter()
for i, span in enumerate(eg.reference.spans[span_key]):
if span.label_ is None:
continue
else:
data["spancat"][span_key][span.label_] += 1
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
data["cats"].update(gold.cats)
if any(val not in (0, 1) for val in gold.cats.values()):
@ -687,14 +761,28 @@ def _format_labels(
return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
def _get_examples_without_label(
data: Sequence[Example],
label: str,
component: Literal["ner", "spancat"] = "ner",
spans_key: Optional[str] = "sc",
) -> int:
count = 0
for eg in data:
labels = [
label.split("-")[1]
for label in eg.get_aligned_ner()
if label not in ("O", "-", None)
]
if component == "ner":
labels = [
label.split("-")[1]
for label in eg.get_aligned_ner()
if label not in ("O", "-", None)
]
if component == "spancat":
labels = (
[span.label_ for span in eg.reference.spans[spans_key]]
if spans_key in eg.reference.spans
else []
)
if label not in labels:
count += 1
return count