mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Update debug data for textcat (#8066)
* Check for unsupported cats values * Only show labels if train/dev mismatched * Don't show label counts (only counting positive labels seems odd) * Use warnings for mismatched train/dev labels
This commit is contained in:
parent
fe3a4aa846
commit
8a2602051c
|
@ -174,7 +174,8 @@ def debug_data(
|
|||
n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values())
|
||||
msg.warn(
|
||||
"{} words in training data without vectors ({:.0f}%)".format(
|
||||
n_missing_vectors, 100 * (n_missing_vectors / gold_train_data["n_words"])
|
||||
n_missing_vectors,
|
||||
100 * (n_missing_vectors / gold_train_data["n_words"]),
|
||||
),
|
||||
)
|
||||
msg.text(
|
||||
|
@ -282,42 +283,7 @@ def debug_data(
|
|||
labels = _get_labels_from_model(nlp, "textcat")
|
||||
msg.info(f"Text Classification: {len(labels)} label(s)")
|
||||
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
|
||||
labels_with_counts = _format_labels(
|
||||
gold_train_data["cats"].most_common(), counts=True
|
||||
)
|
||||
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
|
||||
missing_labels = labels - set(gold_train_data["cats"].keys())
|
||||
if missing_labels:
|
||||
msg.warn(
|
||||
"Some model labels are not present in the train data. The "
|
||||
"model performance may be degraded for these labels after "
|
||||
f"training: {_format_labels(missing_labels)}."
|
||||
)
|
||||
if gold_train_data["n_cats_multilabel"] > 0:
|
||||
# Note: you should never get here because you run into E895 on
|
||||
# initialization first.
|
||||
msg.warn(
|
||||
"The train data contains instances without "
|
||||
"mutually-exclusive classes. Use the component "
|
||||
"'textcat_multilabel' instead of 'textcat'."
|
||||
)
|
||||
if gold_dev_data["n_cats_multilabel"] > 0:
|
||||
msg.fail(
|
||||
"Train/dev mismatch: the dev data contains instances "
|
||||
"without mutually-exclusive classes while the train data "
|
||||
"contains only instances with mutually-exclusive classes."
|
||||
)
|
||||
|
||||
if "textcat_multilabel" in factory_names:
|
||||
msg.divider("Text Classification (Multilabel)")
|
||||
labels = _get_labels_from_model(nlp, "textcat_multilabel")
|
||||
msg.info(f"Text Classification: {len(labels)} label(s)")
|
||||
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
|
||||
labels_with_counts = _format_labels(
|
||||
gold_train_data["cats"].most_common(), counts=True
|
||||
)
|
||||
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
|
||||
missing_labels = labels - set(gold_train_data["cats"].keys())
|
||||
missing_labels = labels - set(gold_train_data["cats"])
|
||||
if missing_labels:
|
||||
msg.warn(
|
||||
"Some model labels are not present in the train data. The "
|
||||
|
@ -325,17 +291,76 @@ def debug_data(
|
|||
f"training: {_format_labels(missing_labels)}."
|
||||
)
|
||||
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
|
||||
msg.fail(
|
||||
f"The train and dev labels are not the same. "
|
||||
msg.warn(
|
||||
"Potential train/dev mismatch: the train and dev labels are "
|
||||
"not the same. "
|
||||
f"Train labels: {_format_labels(gold_train_data['cats'])}. "
|
||||
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
|
||||
)
|
||||
if len(labels) < 2:
|
||||
msg.fail(
|
||||
"The model does not have enough labels. 'textcat' requires at "
|
||||
"least two labels due to mutually-exclusive classes, e.g. "
|
||||
"LABEL/NOT_LABEL or POSITIVE/NEGATIVE for a binary "
|
||||
"classification task."
|
||||
)
|
||||
if (
|
||||
gold_train_data["n_cats_bad_values"] > 0
|
||||
or gold_dev_data["n_cats_bad_values"] > 0
|
||||
):
|
||||
msg.fail(
|
||||
"Unsupported values for cats: the supported values are "
|
||||
"1.0/True and 0.0/False."
|
||||
)
|
||||
if gold_train_data["n_cats_multilabel"] > 0:
|
||||
# Note: you should never get here because you run into E895 on
|
||||
# initialization first.
|
||||
msg.fail(
|
||||
"The train data contains instances without mutually-exclusive "
|
||||
"classes. Use the component 'textcat_multilabel' instead of "
|
||||
"'textcat'."
|
||||
)
|
||||
if gold_dev_data["n_cats_multilabel"] > 0:
|
||||
msg.fail(
|
||||
"The dev data contains instances without mutually-exclusive "
|
||||
"classes. Use the component 'textcat_multilabel' instead of "
|
||||
"'textcat'."
|
||||
)
|
||||
|
||||
if "textcat_multilabel" in factory_names:
|
||||
msg.divider("Text Classification (Multilabel)")
|
||||
labels = _get_labels_from_model(nlp, "textcat_multilabel")
|
||||
msg.info(f"Text Classification: {len(labels)} label(s)")
|
||||
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
|
||||
missing_labels = labels - set(gold_train_data["cats"])
|
||||
if missing_labels:
|
||||
msg.warn(
|
||||
"Some model labels are not present in the train data. The "
|
||||
"model performance may be degraded for these labels after "
|
||||
f"training: {_format_labels(missing_labels)}."
|
||||
)
|
||||
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
|
||||
msg.warn(
|
||||
"Potential train/dev mismatch: the train and dev labels are "
|
||||
"not the same. "
|
||||
f"Train labels: {_format_labels(gold_train_data['cats'])}. "
|
||||
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
|
||||
)
|
||||
if (
|
||||
gold_train_data["n_cats_bad_values"] > 0
|
||||
or gold_dev_data["n_cats_bad_values"] > 0
|
||||
):
|
||||
msg.fail(
|
||||
"Unsupported values for cats: the supported values are "
|
||||
"1.0/True and 0.0/False."
|
||||
)
|
||||
if gold_train_data["n_cats_multilabel"] > 0:
|
||||
if gold_dev_data["n_cats_multilabel"] == 0:
|
||||
msg.warn(
|
||||
"Potential train/dev mismatch: the train data contains "
|
||||
"instances without mutually-exclusive classes while the "
|
||||
"dev data does not."
|
||||
"dev data contains only instances with mutually-exclusive "
|
||||
"classes."
|
||||
)
|
||||
else:
|
||||
msg.warn(
|
||||
|
@ -556,6 +581,7 @@ def _compile_gold(
|
|||
"n_nonproj": 0,
|
||||
"n_cycles": 0,
|
||||
"n_cats_multilabel": 0,
|
||||
"n_cats_bad_values": 0,
|
||||
"texts": set(),
|
||||
}
|
||||
for eg in examples:
|
||||
|
@ -599,7 +625,9 @@ def _compile_gold(
|
|||
data["ner"]["-"] += 1
|
||||
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
|
||||
data["cats"].update(gold.cats)
|
||||
if list(gold.cats.values()).count(1.0) != 1:
|
||||
if any(val not in (0, 1) for val in gold.cats.values()):
|
||||
data["n_cats_bad_values"] += 1
|
||||
if list(gold.cats.values()).count(1) != 1:
|
||||
data["n_cats_multilabel"] += 1
|
||||
if "tagger" in factory_names:
|
||||
tags = eg.get_aligned("TAG", as_string=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user