diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index f20673f25..e29ee71a2 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -7,6 +7,7 @@ import srsly from wasabi import Printer, MESSAGES, msg import typer import math +import numpy as np from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import import_code, debug_cli, _format_number @@ -521,9 +522,13 @@ def debug_data( if "tagger" in factory_names: msg.divider("Part-of-speech Tagging") - label_list = [label for label in gold_train_data["tags"]] - model_labels = _get_labels_from_model(nlp, "tagger") + label_list, counts = zip(*gold_train_data["tags"].items()) msg.info(f"{len(label_list)} label(s) in train data") + p = np.array(counts) + p = p / p.sum() + entropy = np.round((-p*np.log2(p)).sum(), 2) + msg.info(f"{entropy} is the train data label entropy") + model_labels = _get_labels_from_model(nlp, "tagger") labels = set(label_list) missing_labels = model_labels - labels if missing_labels: