From 8a2602051c0b878045ed10ecd84ce61f20130a87 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 17 May 2021 13:27:04 +0200 Subject: [PATCH] Update debug data for textcat (#8066) * Check for unsupported cats values * Only show labels if train/dev mismatched * Don't show label counts (only counting positive labels seems odd) * Use warnings for mismatched train/dev labels --- spacy/cli/debug_data.py | 110 +++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 41 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index 1ebf65957..b4119abdf 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -174,7 +174,8 @@ def debug_data( n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values()) msg.warn( "{} words in training data without vectors ({:.0f}%)".format( - n_missing_vectors, 100 * (n_missing_vectors / gold_train_data["n_words"]) + n_missing_vectors, + 100 * (n_missing_vectors / gold_train_data["n_words"]), ), ) msg.text( @@ -282,42 +283,7 @@ def debug_data( labels = _get_labels_from_model(nlp, "textcat") msg.info(f"Text Classification: {len(labels)} label(s)") msg.text(f"Labels: {_format_labels(labels)}", show=verbose) - labels_with_counts = _format_labels( - gold_train_data["cats"].most_common(), counts=True - ) - msg.text(f"Labels in train data: {labels_with_counts}", show=verbose) - missing_labels = labels - set(gold_train_data["cats"].keys()) - if missing_labels: - msg.warn( - "Some model labels are not present in the train data. The " - "model performance may be degraded for these labels after " - f"training: {_format_labels(missing_labels)}." - ) - if gold_train_data["n_cats_multilabel"] > 0: - # Note: you should never get here because you run into E895 on - # initialization first. - msg.warn( - "The train data contains instances without " - "mutually-exclusive classes. Use the component " - "'textcat_multilabel' instead of 'textcat'." - ) - if gold_dev_data["n_cats_multilabel"] > 0: - msg.fail( - "Train/dev mismatch: the dev data contains instances " - "without mutually-exclusive classes while the train data " - "contains only instances with mutually-exclusive classes." - ) - - if "textcat_multilabel" in factory_names: - msg.divider("Text Classification (Multilabel)") - labels = _get_labels_from_model(nlp, "textcat_multilabel") - msg.info(f"Text Classification: {len(labels)} label(s)") - msg.text(f"Labels: {_format_labels(labels)}", show=verbose) - labels_with_counts = _format_labels( - gold_train_data["cats"].most_common(), counts=True - ) - msg.text(f"Labels in train data: {labels_with_counts}", show=verbose) - missing_labels = labels - set(gold_train_data["cats"].keys()) + missing_labels = labels - set(gold_train_data["cats"]) if missing_labels: msg.warn( "Some model labels are not present in the train data. The " @@ -325,17 +291,76 @@ def debug_data( f"training: {_format_labels(missing_labels)}." ) if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]): - msg.fail( - f"The train and dev labels are not the same. " + msg.warn( + "Potential train/dev mismatch: the train and dev labels are " + "not the same. " f"Train labels: {_format_labels(gold_train_data['cats'])}. " f"Dev labels: {_format_labels(gold_dev_data['cats'])}." ) + if len(labels) < 2: + msg.fail( + "The model does not have enough labels. 'textcat' requires at " + "least two labels due to mutually-exclusive classes, e.g. " + "LABEL/NOT_LABEL or POSITIVE/NEGATIVE for a binary " + "classification task." + ) + if ( + gold_train_data["n_cats_bad_values"] > 0 + or gold_dev_data["n_cats_bad_values"] > 0 + ): + msg.fail( + "Unsupported values for cats: the supported values are " + "1.0/True and 0.0/False." + ) + if gold_train_data["n_cats_multilabel"] > 0: + # Note: you should never get here because you run into E895 on + # initialization first. + msg.fail( + "The train data contains instances without mutually-exclusive " + "classes. Use the component 'textcat_multilabel' instead of " + "'textcat'." + ) + if gold_dev_data["n_cats_multilabel"] > 0: + msg.fail( + "The dev data contains instances without mutually-exclusive " + "classes. Use the component 'textcat_multilabel' instead of " + "'textcat'." + ) + + if "textcat_multilabel" in factory_names: + msg.divider("Text Classification (Multilabel)") + labels = _get_labels_from_model(nlp, "textcat_multilabel") + msg.info(f"Text Classification: {len(labels)} label(s)") + msg.text(f"Labels: {_format_labels(labels)}", show=verbose) + missing_labels = labels - set(gold_train_data["cats"]) + if missing_labels: + msg.warn( + "Some model labels are not present in the train data. The " + "model performance may be degraded for these labels after " + f"training: {_format_labels(missing_labels)}." + ) + if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]): + msg.warn( + "Potential train/dev mismatch: the train and dev labels are " + "not the same. " + f"Train labels: {_format_labels(gold_train_data['cats'])}. " + f"Dev labels: {_format_labels(gold_dev_data['cats'])}." + ) + if ( + gold_train_data["n_cats_bad_values"] > 0 + or gold_dev_data["n_cats_bad_values"] > 0 + ): + msg.fail( + "Unsupported values for cats: the supported values are " + "1.0/True and 0.0/False." + ) if gold_train_data["n_cats_multilabel"] > 0: if gold_dev_data["n_cats_multilabel"] == 0: msg.warn( "Potential train/dev mismatch: the train data contains " "instances without mutually-exclusive classes while the " - "dev data does not." + "dev data contains only instances with mutually-exclusive " + "classes." ) else: msg.warn( @@ -556,6 +581,7 @@ def _compile_gold( "n_nonproj": 0, "n_cycles": 0, "n_cats_multilabel": 0, + "n_cats_bad_values": 0, "texts": set(), } for eg in examples: @@ -599,7 +625,9 @@ def _compile_gold( data["ner"]["-"] += 1 if "textcat" in factory_names or "textcat_multilabel" in factory_names: data["cats"].update(gold.cats) - if list(gold.cats.values()).count(1.0) != 1: + if any(val not in (0, 1) for val in gold.cats.values()): + data["n_cats_bad_values"] += 1 + if list(gold.cats.values()).count(1) != 1: data["n_cats_multilabel"] += 1 if "tagger" in factory_names: tags = eg.get_aligned("TAG", as_string=True)