mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 18:52:29 +03:00
add entropy to debug data
This commit is contained in:
parent
38e5a75014
commit
d7406fffb0
|
@ -7,6 +7,7 @@ import srsly
|
||||||
from wasabi import Printer, MESSAGES, msg
|
from wasabi import Printer, MESSAGES, msg
|
||||||
import typer
|
import typer
|
||||||
import math
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||||
from ._util import import_code, debug_cli, _format_number
|
from ._util import import_code, debug_cli, _format_number
|
||||||
|
@ -521,9 +522,13 @@ def debug_data(
|
||||||
|
|
||||||
if "tagger" in factory_names:
|
if "tagger" in factory_names:
|
||||||
msg.divider("Part-of-speech Tagging")
|
msg.divider("Part-of-speech Tagging")
|
||||||
label_list = [label for label in gold_train_data["tags"]]
|
label_list, counts = zip(*gold_train_data["tags"].items())
|
||||||
model_labels = _get_labels_from_model(nlp, "tagger")
|
|
||||||
msg.info(f"{len(label_list)} label(s) in train data")
|
msg.info(f"{len(label_list)} label(s) in train data")
|
||||||
|
p = np.array(counts)
|
||||||
|
p = p / p.sum()
|
||||||
|
entropy = np.round((-p*np.log2(p)).sum(), 2)
|
||||||
|
msg.info(f"{entropy} is the train data label entropy")
|
||||||
|
model_labels = _get_labels_from_model(nlp, "tagger")
|
||||||
labels = set(label_list)
|
labels = set(label_list)
|
||||||
missing_labels = model_labels - labels
|
missing_labels = model_labels - labels
|
||||||
if missing_labels:
|
if missing_labels:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user