add entropy of the tag label distribution to `debug data` output

This commit is contained in:
vinit 2023-02-17 01:35:33 +05:30
parent 38e5a75014
commit d7406fffb0

View File

@@ -7,6 +7,7 @@ import srsly
from wasabi import Printer, MESSAGES, msg
import typer
import math
import numpy as np
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli, _format_number
@@ -521,9 +522,13 @@ def debug_data(
if "tagger" in factory_names:
msg.divider("Part-of-speech Tagging")
label_list = [label for label in gold_train_data["tags"]]
model_labels = _get_labels_from_model(nlp, "tagger")
label_list, counts = zip(*gold_train_data["tags"].items())
msg.info(f"{len(label_list)} label(s) in train data")
p = np.array(counts)
p = p / p.sum()
entropy = np.round((-p*np.log2(p)).sum(), 2)
msg.info(f"{entropy} is the train data label entropy")
model_labels = _get_labels_from_model(nlp, "tagger")
labels = set(label_list)
missing_labels = model_labels - labels
if missing_labels: