mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 10:42:34 +03:00
add entropy to debug data
This commit is contained in:
parent
38e5a75014
commit
d7406fffb0
|
@ -7,6 +7,7 @@ import srsly
|
|||
from wasabi import Printer, MESSAGES, msg
|
||||
import typer
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli, _format_number
|
||||
|
@ -521,9 +522,13 @@ def debug_data(
|
|||
|
||||
if "tagger" in factory_names:
|
||||
msg.divider("Part-of-speech Tagging")
|
||||
label_list = [label for label in gold_train_data["tags"]]
|
||||
model_labels = _get_labels_from_model(nlp, "tagger")
|
||||
label_list, counts = zip(*gold_train_data["tags"].items())
|
||||
msg.info(f"{len(label_list)} label(s) in train data")
|
||||
p = np.array(counts)
|
||||
p = p / p.sum()
|
||||
entropy = np.round((-p*np.log2(p)).sum(), 2)
|
||||
msg.info(f"{entropy} is the train data label entropy")
|
||||
model_labels = _get_labels_from_model(nlp, "tagger")
|
||||
labels = set(label_list)
|
||||
missing_labels = model_labels - labels
|
||||
if missing_labels:
|
||||
|
|
Loading…
Reference in New Issue
Block a user