Update debug data for textcat (#8066)

* Check for unsupported cats values
* Only show labels if train/dev mismatched
* Don't show label counts (only counting positive labels seems odd)
* Use warnings for mismatched train/dev labels
This commit is contained in:
Adriane Boyd 2021-05-17 13:27:04 +02:00 committed by GitHub
parent fe3a4aa846
commit 8a2602051c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -174,7 +174,8 @@ def debug_data(
n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values())
msg.warn(
"{} words in training data without vectors ({:.0f}%)".format(
n_missing_vectors, 100 * (n_missing_vectors / gold_train_data["n_words"])
n_missing_vectors,
100 * (n_missing_vectors / gold_train_data["n_words"]),
),
)
msg.text(
@ -282,42 +283,7 @@ def debug_data(
labels = _get_labels_from_model(nlp, "textcat")
msg.info(f"Text Classification: {len(labels)} label(s)")
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
labels_with_counts = _format_labels(
gold_train_data["cats"].most_common(), counts=True
)
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
missing_labels = labels - set(gold_train_data["cats"].keys())
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
"model performance may be degraded for these labels after "
f"training: {_format_labels(missing_labels)}."
)
if gold_train_data["n_cats_multilabel"] > 0:
# Note: you should never get here because you run into E895 on
# initialization first.
msg.warn(
"The train data contains instances without "
"mutually-exclusive classes. Use the component "
"'textcat_multilabel' instead of 'textcat'."
)
if gold_dev_data["n_cats_multilabel"] > 0:
msg.fail(
"Train/dev mismatch: the dev data contains instances "
"without mutually-exclusive classes while the train data "
"contains only instances with mutually-exclusive classes."
)
if "textcat_multilabel" in factory_names:
msg.divider("Text Classification (Multilabel)")
labels = _get_labels_from_model(nlp, "textcat_multilabel")
msg.info(f"Text Classification: {len(labels)} label(s)")
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
labels_with_counts = _format_labels(
gold_train_data["cats"].most_common(), counts=True
)
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
missing_labels = labels - set(gold_train_data["cats"].keys())
missing_labels = labels - set(gold_train_data["cats"])
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
@ -325,17 +291,76 @@ def debug_data(
f"training: {_format_labels(missing_labels)}."
)
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
msg.fail(
f"The train and dev labels are not the same. "
msg.warn(
"Potential train/dev mismatch: the train and dev labels are "
"not the same. "
f"Train labels: {_format_labels(gold_train_data['cats'])}. "
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
)
if len(labels) < 2:
msg.fail(
"The model does not have enough labels. 'textcat' requires at "
"least two labels due to mutually-exclusive classes, e.g. "
"LABEL/NOT_LABEL or POSITIVE/NEGATIVE for a binary "
"classification task."
)
if (
gold_train_data["n_cats_bad_values"] > 0
or gold_dev_data["n_cats_bad_values"] > 0
):
msg.fail(
"Unsupported values for cats: the supported values are "
"1.0/True and 0.0/False."
)
if gold_train_data["n_cats_multilabel"] > 0:
# Note: you should never get here because you run into E895 on
# initialization first.
msg.fail(
"The train data contains instances without mutually-exclusive "
"classes. Use the component 'textcat_multilabel' instead of "
"'textcat'."
)
if gold_dev_data["n_cats_multilabel"] > 0:
msg.fail(
"The dev data contains instances without mutually-exclusive "
"classes. Use the component 'textcat_multilabel' instead of "
"'textcat'."
)
if "textcat_multilabel" in factory_names:
msg.divider("Text Classification (Multilabel)")
labels = _get_labels_from_model(nlp, "textcat_multilabel")
msg.info(f"Text Classification: {len(labels)} label(s)")
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
missing_labels = labels - set(gold_train_data["cats"])
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
"model performance may be degraded for these labels after "
f"training: {_format_labels(missing_labels)}."
)
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
msg.warn(
"Potential train/dev mismatch: the train and dev labels are "
"not the same. "
f"Train labels: {_format_labels(gold_train_data['cats'])}. "
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
)
if (
gold_train_data["n_cats_bad_values"] > 0
or gold_dev_data["n_cats_bad_values"] > 0
):
msg.fail(
"Unsupported values for cats: the supported values are "
"1.0/True and 0.0/False."
)
if gold_train_data["n_cats_multilabel"] > 0:
if gold_dev_data["n_cats_multilabel"] == 0:
msg.warn(
"Potential train/dev mismatch: the train data contains "
"instances without mutually-exclusive classes while the "
"dev data does not."
"dev data contains only instances with mutually-exclusive "
"classes."
)
else:
msg.warn(
@ -556,6 +581,7 @@ def _compile_gold(
"n_nonproj": 0,
"n_cycles": 0,
"n_cats_multilabel": 0,
"n_cats_bad_values": 0,
"texts": set(),
}
for eg in examples:
@ -599,7 +625,9 @@ def _compile_gold(
data["ner"]["-"] += 1
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
data["cats"].update(gold.cats)
if list(gold.cats.values()).count(1.0) != 1:
if any(val not in (0, 1) for val in gold.cats.values()):
data["n_cats_bad_values"] += 1
if list(gold.cats.values()).count(1) != 1:
data["n_cats_multilabel"] += 1
if "tagger" in factory_names:
tags = eg.get_aligned("TAG", as_string=True)