mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Update debug data further for v3 (#7602)
* Update debug data further for v3
* Remove new/existing label distinction (new labels are not immediately distinguishable because the pipeline is already initialized)
* Warn on missing labels in training data for all components except parser
* Separate textcat and textcat_multilabel sections
* Add section for morphologizer
* Reword missing label warnings
This commit is contained in:
parent
2516896849
commit
73a8c0f992
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Sequence, Dict, Any, Tuple, Optional
|
from typing import List, Sequence, Dict, Any, Tuple, Optional, Set
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import sys
|
import sys
|
||||||
|
@ -13,6 +13,8 @@ from ..training.initialize import get_sourced_components
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
from ..pipeline._parser_internals.nonproj import DELIMITER
|
from ..pipeline._parser_internals.nonproj import DELIMITER
|
||||||
|
from ..pipeline import Morphologizer
|
||||||
|
from ..morphology import Morphology
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..util import registry, resolve_dot_names
|
from ..util import registry, resolve_dot_names
|
||||||
from .. import util
|
from .. import util
|
||||||
|
@ -194,32 +196,32 @@ def debug_data(
|
||||||
)
|
)
|
||||||
label_counts = gold_train_data["ner"]
|
label_counts = gold_train_data["ner"]
|
||||||
model_labels = _get_labels_from_model(nlp, "ner")
|
model_labels = _get_labels_from_model(nlp, "ner")
|
||||||
new_labels = [l for l in labels if l not in model_labels]
|
|
||||||
existing_labels = [l for l in labels if l in model_labels]
|
|
||||||
has_low_data_warning = False
|
has_low_data_warning = False
|
||||||
has_no_neg_warning = False
|
has_no_neg_warning = False
|
||||||
has_ws_ents_error = False
|
has_ws_ents_error = False
|
||||||
has_punct_ents_warning = False
|
has_punct_ents_warning = False
|
||||||
|
|
||||||
msg.divider("Named Entity Recognition")
|
msg.divider("Named Entity Recognition")
|
||||||
msg.info(
|
msg.info(f"{len(model_labels)} label(s)")
|
||||||
f"{len(new_labels)} new label(s), {len(existing_labels)} existing label(s)"
|
|
||||||
)
|
|
||||||
missing_values = label_counts["-"]
|
missing_values = label_counts["-"]
|
||||||
msg.text(f"{missing_values} missing value(s) (tokens with '-' label)")
|
msg.text(f"{missing_values} missing value(s) (tokens with '-' label)")
|
||||||
for label in new_labels:
|
for label in labels:
|
||||||
if len(label) == 0:
|
if len(label) == 0:
|
||||||
msg.fail("Empty label found in new labels")
|
msg.fail("Empty label found in train data")
|
||||||
if new_labels:
|
|
||||||
labels_with_counts = [
|
labels_with_counts = [
|
||||||
(label, count)
|
(label, count)
|
||||||
for label, count in label_counts.most_common()
|
for label, count in label_counts.most_common()
|
||||||
if label != "-"
|
if label != "-"
|
||||||
]
|
]
|
||||||
labels_with_counts = _format_labels(labels_with_counts, counts=True)
|
labels_with_counts = _format_labels(labels_with_counts, counts=True)
|
||||||
msg.text(f"New: {labels_with_counts}", show=verbose)
|
msg.text(f"Labels in train data: {_format_labels(labels)}", show=verbose)
|
||||||
if existing_labels:
|
missing_labels = model_labels - labels
|
||||||
msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
|
if missing_labels:
|
||||||
|
msg.warn(
|
||||||
|
"Some model labels are not present in the train data. The "
|
||||||
|
"model performance may be degraded for these labels after "
|
||||||
|
f"training: {_format_labels(missing_labels)}."
|
||||||
|
)
|
||||||
if gold_train_data["ws_ents"]:
|
if gold_train_data["ws_ents"]:
|
||||||
msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
|
msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
|
||||||
has_ws_ents_error = True
|
has_ws_ents_error = True
|
||||||
|
@ -228,10 +230,10 @@ def debug_data(
|
||||||
msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
|
msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
|
||||||
has_punct_ents_warning = True
|
has_punct_ents_warning = True
|
||||||
|
|
||||||
for label in new_labels:
|
for label in labels:
|
||||||
if label_counts[label] <= NEW_LABEL_THRESHOLD:
|
if label_counts[label] <= NEW_LABEL_THRESHOLD:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
f"Low number of examples for new label '{label}' ({label_counts[label]})"
|
f"Low number of examples for label '{label}' ({label_counts[label]})"
|
||||||
)
|
)
|
||||||
has_low_data_warning = True
|
has_low_data_warning = True
|
||||||
|
|
||||||
|
@ -276,22 +278,52 @@ def debug_data(
|
||||||
)
|
)
|
||||||
|
|
||||||
if "textcat" in factory_names:
|
if "textcat" in factory_names:
|
||||||
msg.divider("Text Classification")
|
msg.divider("Text Classification (Exclusive Classes)")
|
||||||
labels = [label for label in gold_train_data["cats"]]
|
labels = _get_labels_from_model(nlp, "textcat")
|
||||||
model_labels = _get_labels_from_model(nlp, "textcat")
|
msg.info(f"Text Classification: {len(labels)} label(s)")
|
||||||
new_labels = [l for l in labels if l not in model_labels]
|
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
|
||||||
existing_labels = [l for l in labels if l in model_labels]
|
|
||||||
msg.info(
|
|
||||||
f"Text Classification: {len(new_labels)} new label(s), "
|
|
||||||
f"{len(existing_labels)} existing label(s)"
|
|
||||||
)
|
|
||||||
if new_labels:
|
|
||||||
labels_with_counts = _format_labels(
|
labels_with_counts = _format_labels(
|
||||||
gold_train_data["cats"].most_common(), counts=True
|
gold_train_data["cats"].most_common(), counts=True
|
||||||
)
|
)
|
||||||
msg.text(f"New: {labels_with_counts}", show=verbose)
|
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
|
||||||
if existing_labels:
|
missing_labels = labels - set(gold_train_data["cats"].keys())
|
||||||
msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
|
if missing_labels:
|
||||||
|
msg.warn(
|
||||||
|
"Some model labels are not present in the train data. The "
|
||||||
|
"model performance may be degraded for these labels after "
|
||||||
|
f"training: {_format_labels(missing_labels)}."
|
||||||
|
)
|
||||||
|
if gold_train_data["n_cats_multilabel"] > 0:
|
||||||
|
# Note: you should never get here because you run into E895 on
|
||||||
|
# initialization first.
|
||||||
|
msg.warn(
|
||||||
|
"The train data contains instances without "
|
||||||
|
"mutually-exclusive classes. Use the component "
|
||||||
|
"'textcat_multilabel' instead of 'textcat'."
|
||||||
|
)
|
||||||
|
if gold_dev_data["n_cats_multilabel"] > 0:
|
||||||
|
msg.fail(
|
||||||
|
"Train/dev mismatch: the dev data contains instances "
|
||||||
|
"without mutually-exclusive classes while the train data "
|
||||||
|
"contains only instances with mutually-exclusive classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
if "textcat_multilabel" in factory_names:
|
||||||
|
msg.divider("Text Classification (Multilabel)")
|
||||||
|
labels = _get_labels_from_model(nlp, "textcat_multilabel")
|
||||||
|
msg.info(f"Text Classification: {len(labels)} label(s)")
|
||||||
|
msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
|
||||||
|
labels_with_counts = _format_labels(
|
||||||
|
gold_train_data["cats"].most_common(), counts=True
|
||||||
|
)
|
||||||
|
msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
|
||||||
|
missing_labels = labels - set(gold_train_data["cats"].keys())
|
||||||
|
if missing_labels:
|
||||||
|
msg.warn(
|
||||||
|
"Some model labels are not present in the train data. The "
|
||||||
|
"model performance may be degraded for these labels after "
|
||||||
|
f"training: {_format_labels(missing_labels)}."
|
||||||
|
)
|
||||||
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
|
if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"The train and dev labels are not the same. "
|
f"The train and dev labels are not the same. "
|
||||||
|
@ -299,11 +331,6 @@ def debug_data(
|
||||||
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
|
f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
|
||||||
)
|
)
|
||||||
if gold_train_data["n_cats_multilabel"] > 0:
|
if gold_train_data["n_cats_multilabel"] > 0:
|
||||||
msg.info(
|
|
||||||
"The train data contains instances without "
|
|
||||||
"mutually-exclusive classes. Use '--textcat-multilabel' "
|
|
||||||
"when training."
|
|
||||||
)
|
|
||||||
if gold_dev_data["n_cats_multilabel"] == 0:
|
if gold_dev_data["n_cats_multilabel"] == 0:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"Potential train/dev mismatch: the train data contains "
|
"Potential train/dev mismatch: the train data contains "
|
||||||
|
@ -311,9 +338,10 @@ def debug_data(
|
||||||
"dev data does not."
|
"dev data does not."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg.info(
|
msg.warn(
|
||||||
"The train data contains only instances with "
|
"The train data contains only instances with "
|
||||||
"mutually-exclusive classes."
|
"mutually-exclusive classes. You can potentially use the "
|
||||||
|
"component 'textcat' instead of 'textcat_multilabel'."
|
||||||
)
|
)
|
||||||
if gold_dev_data["n_cats_multilabel"] > 0:
|
if gold_dev_data["n_cats_multilabel"] > 0:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
|
@ -325,13 +353,37 @@ def debug_data(
|
||||||
if "tagger" in factory_names:
|
if "tagger" in factory_names:
|
||||||
msg.divider("Part-of-speech Tagging")
|
msg.divider("Part-of-speech Tagging")
|
||||||
labels = [label for label in gold_train_data["tags"]]
|
labels = [label for label in gold_train_data["tags"]]
|
||||||
# TODO: does this need to be updated?
|
model_labels = _get_labels_from_model(nlp, "tagger")
|
||||||
msg.info(f"{len(labels)} label(s) in data")
|
msg.info(f"{len(labels)} label(s) in train data")
|
||||||
|
missing_labels = model_labels - set(labels)
|
||||||
|
if missing_labels:
|
||||||
|
msg.warn(
|
||||||
|
"Some model labels are not present in the train data. The "
|
||||||
|
"model performance may be degraded for these labels after "
|
||||||
|
f"training: {_format_labels(missing_labels)}."
|
||||||
|
)
|
||||||
labels_with_counts = _format_labels(
|
labels_with_counts = _format_labels(
|
||||||
gold_train_data["tags"].most_common(), counts=True
|
gold_train_data["tags"].most_common(), counts=True
|
||||||
)
|
)
|
||||||
msg.text(labels_with_counts, show=verbose)
|
msg.text(labels_with_counts, show=verbose)
|
||||||
|
|
||||||
|
if "morphologizer" in factory_names:
|
||||||
|
msg.divider("Morphologizer (POS+Morph)")
|
||||||
|
labels = [label for label in gold_train_data["morphs"]]
|
||||||
|
model_labels = _get_labels_from_model(nlp, "morphologizer")
|
||||||
|
msg.info(f"{len(labels)} label(s) in train data")
|
||||||
|
missing_labels = model_labels - set(labels)
|
||||||
|
if missing_labels:
|
||||||
|
msg.warn(
|
||||||
|
"Some model labels are not present in the train data. The "
|
||||||
|
"model performance may be degraded for these labels after "
|
||||||
|
f"training: {_format_labels(missing_labels)}."
|
||||||
|
)
|
||||||
|
labels_with_counts = _format_labels(
|
||||||
|
gold_train_data["morphs"].most_common(), counts=True
|
||||||
|
)
|
||||||
|
msg.text(labels_with_counts, show=verbose)
|
||||||
|
|
||||||
if "parser" in factory_names:
|
if "parser" in factory_names:
|
||||||
has_low_data_warning = False
|
has_low_data_warning = False
|
||||||
msg.divider("Dependency Parsing")
|
msg.divider("Dependency Parsing")
|
||||||
|
@ -491,6 +543,7 @@ def _compile_gold(
|
||||||
"ner": Counter(),
|
"ner": Counter(),
|
||||||
"cats": Counter(),
|
"cats": Counter(),
|
||||||
"tags": Counter(),
|
"tags": Counter(),
|
||||||
|
"morphs": Counter(),
|
||||||
"deps": Counter(),
|
"deps": Counter(),
|
||||||
"words": Counter(),
|
"words": Counter(),
|
||||||
"roots": Counter(),
|
"roots": Counter(),
|
||||||
|
@ -544,13 +597,36 @@ def _compile_gold(
|
||||||
data["ner"][combined_label] += 1
|
data["ner"][combined_label] += 1
|
||||||
elif label == "-":
|
elif label == "-":
|
||||||
data["ner"]["-"] += 1
|
data["ner"]["-"] += 1
|
||||||
if "textcat" in factory_names:
|
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
|
||||||
data["cats"].update(gold.cats)
|
data["cats"].update(gold.cats)
|
||||||
if list(gold.cats.values()).count(1.0) != 1:
|
if list(gold.cats.values()).count(1.0) != 1:
|
||||||
data["n_cats_multilabel"] += 1
|
data["n_cats_multilabel"] += 1
|
||||||
if "tagger" in factory_names:
|
if "tagger" in factory_names:
|
||||||
tags = eg.get_aligned("TAG", as_string=True)
|
tags = eg.get_aligned("TAG", as_string=True)
|
||||||
data["tags"].update([x for x in tags if x is not None])
|
data["tags"].update([x for x in tags if x is not None])
|
||||||
|
if "morphologizer" in factory_names:
|
||||||
|
pos_tags = eg.get_aligned("POS", as_string=True)
|
||||||
|
morphs = eg.get_aligned("MORPH", as_string=True)
|
||||||
|
for pos, morph in zip(pos_tags, morphs):
|
||||||
|
# POS may align (same value for multiple tokens) when morph
|
||||||
|
# doesn't, so if either is misaligned (None), treat the
|
||||||
|
# annotation as missing so that truths doesn't end up with an
|
||||||
|
# unknown morph+POS combination
|
||||||
|
if pos is None or morph is None:
|
||||||
|
pass
|
||||||
|
# If both are unset, the annotation is missing (empty morph
|
||||||
|
# converted from int is "_" rather than "")
|
||||||
|
elif pos == "" and morph == "":
|
||||||
|
pass
|
||||||
|
# Otherwise, generate the combined label
|
||||||
|
else:
|
||||||
|
label_dict = Morphology.feats_to_dict(morph)
|
||||||
|
if pos:
|
||||||
|
label_dict[Morphologizer.POS_FEAT] = pos
|
||||||
|
label = eg.reference.vocab.strings[
|
||||||
|
eg.reference.vocab.morphology.add(label_dict)
|
||||||
|
]
|
||||||
|
data["morphs"].update([label])
|
||||||
if "parser" in factory_names:
|
if "parser" in factory_names:
|
||||||
aligned_heads, aligned_deps = eg.get_aligned_parse(projectivize=make_proj)
|
aligned_heads, aligned_deps = eg.get_aligned_parse(projectivize=make_proj)
|
||||||
data["deps"].update([x for x in aligned_deps if x is not None])
|
data["deps"].update([x for x in aligned_deps if x is not None])
|
||||||
|
@ -584,8 +660,8 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Sequence[str]:
|
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
|
||||||
if pipe_name not in nlp.pipe_names:
|
if pipe_name not in nlp.pipe_names:
|
||||||
return set()
|
return set()
|
||||||
pipe = nlp.get_pipe(pipe_name)
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
return pipe.labels
|
return set(pipe.labels)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user