From d7406fffb0368ec0aaca06e534b510c8e754277c Mon Sep 17 00:00:00 2001 From: vinit Date: Fri, 17 Feb 2023 01:35:33 +0530 Subject: [PATCH] add entropy to debug data --- spacy/cli/debug_data.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index f20673f25..e29ee71a2 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -7,6 +7,7 @@ import srsly from wasabi import Printer, MESSAGES, msg import typer import math +import numpy as np from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import import_code, debug_cli, _format_number @@ -521,9 +522,13 @@ def debug_data( if "tagger" in factory_names: msg.divider("Part-of-speech Tagging") - label_list = [label for label in gold_train_data["tags"]] - model_labels = _get_labels_from_model(nlp, "tagger") + label_list, counts = zip(*gold_train_data["tags"].items()) msg.info(f"{len(label_list)} label(s) in train data") + p = np.array(counts) + p = p / p.sum() + entropy = np.round((-p*np.log2(p)).sum(), 2) + msg.info(f"{entropy} is the train data label entropy") + model_labels = _get_labels_from_model(nlp, "tagger") labels = set(label_list) missing_labels = model_labels - labels if missing_labels: