mirror of https://github.com/explosion/spaCy.git
Update debug data further for v3 (#7602)
* Update debug data further for v3
* Remove new/existing label distinction (new labels are not immediately distinguishable because the pipeline is already initialized)
* Warn on missing labels in training data for all components except parser (see the sketch below)
* Separate textcat and textcat_multilabel sections
* Add section for morphologizer
* Reword missing label warnings
parent 2516896849
commit 73a8c0f992
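The missing-label warnings this commit adds all boil down to the same set comparison: labels the initialized component can predict versus labels actually observed in the training data. A minimal sketch of that check (not part of the diff; the pipeline, label names and counts below are made up for illustration):

from collections import Counter
import spacy

# Hypothetical pipeline: an NER component initialized with three labels
nlp = spacy.blank("en")
ner = nlp.add_pipe("ner")
for label in ("PERSON", "ORG", "LOC"):
    ner.add_label(label)

# Hypothetical label counts collected from the training corpus
train_label_counts = Counter({"PERSON": 120, "ORG": 4})

model_labels = set(ner.labels)          # labels the component can predict
train_labels = set(train_label_counts)  # labels present in the train data
missing_labels = model_labels - train_labels
if missing_labels:
    print(f"Not present in train data: {sorted(missing_labels)}")  # ['LOC']

The diff applies this pattern to ner, textcat, textcat_multilabel, tagger and morphologizer; the parser is deliberately left out, as noted in the commit message.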
@@ -1,4 +1,4 @@
-from typing import List, Sequence, Dict, Any, Tuple, Optional
+from typing import List, Sequence, Dict, Any, Tuple, Optional, Set
 from pathlib import Path
 from collections import Counter
 import sys
@@ -13,6 +13,8 @@ from ..training.initialize import get_sourced_components
 from ..schemas import ConfigSchemaTraining
 from ..pipeline._parser_internals import nonproj
 from ..pipeline._parser_internals.nonproj import DELIMITER
+from ..pipeline import Morphologizer
+from ..morphology import Morphology
 from ..language import Language
 from ..util import registry, resolve_dot_names
 from .. import util
@@ -194,32 +196,32 @@ def debug_data(
         )
         label_counts = gold_train_data["ner"]
         model_labels = _get_labels_from_model(nlp, "ner")
-        new_labels = [l for l in labels if l not in model_labels]
-        existing_labels = [l for l in labels if l in model_labels]
         has_low_data_warning = False
         has_no_neg_warning = False
         has_ws_ents_error = False
         has_punct_ents_warning = False

         msg.divider("Named Entity Recognition")
-        msg.info(
-            f"{len(new_labels)} new label(s), {len(existing_labels)} existing label(s)"
-        )
+        msg.info(f"{len(model_labels)} label(s)")
         missing_values = label_counts["-"]
         msg.text(f"{missing_values} missing value(s) (tokens with '-' label)")
-        for label in new_labels:
+        for label in labels:
             if len(label) == 0:
-                msg.fail("Empty label found in new labels")
-        if new_labels:
-            labels_with_counts = [
-                (label, count)
-                for label, count in label_counts.most_common()
-                if label != "-"
-            ]
-            labels_with_counts = _format_labels(labels_with_counts, counts=True)
-            msg.text(f"New: {labels_with_counts}", show=verbose)
-        if existing_labels:
-            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
+                msg.fail("Empty label found in train data")
+        labels_with_counts = [
+            (label, count)
+            for label, count in label_counts.most_common()
+            if label != "-"
+        ]
+        labels_with_counts = _format_labels(labels_with_counts, counts=True)
+        msg.text(f"Labels in train data: {_format_labels(labels)}", show=verbose)
+        missing_labels = model_labels - labels
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
         if gold_train_data["ws_ents"]:
             msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
             has_ws_ents_error = True
@@ -228,10 +230,10 @@ def debug_data(
             msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
             has_punct_ents_warning = True

-        for label in new_labels:
+        for label in labels:
             if label_counts[label] <= NEW_LABEL_THRESHOLD:
                 msg.warn(
-                    f"Low number of examples for new label '{label}' ({label_counts[label]})"
+                    f"Low number of examples for label '{label}' ({label_counts[label]})"
                 )
                 has_low_data_warning = True

@@ -276,22 +278,52 @@ def debug_data(
             )

     if "textcat" in factory_names:
-        msg.divider("Text Classification")
-        labels = [label for label in gold_train_data["cats"]]
-        model_labels = _get_labels_from_model(nlp, "textcat")
-        new_labels = [l for l in labels if l not in model_labels]
-        existing_labels = [l for l in labels if l in model_labels]
-        msg.info(
-            f"Text Classification: {len(new_labels)} new label(s), "
-            f"{len(existing_labels)} existing label(s)"
+        msg.divider("Text Classification (Exclusive Classes)")
+        labels = _get_labels_from_model(nlp, "textcat")
+        msg.info(f"Text Classification: {len(labels)} label(s)")
+        msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
+        labels_with_counts = _format_labels(
+            gold_train_data["cats"].most_common(), counts=True
         )
-        if new_labels:
-            labels_with_counts = _format_labels(
-                gold_train_data["cats"].most_common(), counts=True
+        msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
+        missing_labels = labels - set(gold_train_data["cats"].keys())
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
+        if gold_train_data["n_cats_multilabel"] > 0:
+            # Note: you should never get here because you run into E895 on
+            # initialization first.
+            msg.warn(
+                "The train data contains instances without "
+                "mutually-exclusive classes. Use the component "
+                "'textcat_multilabel' instead of 'textcat'."
+            )
+        if gold_dev_data["n_cats_multilabel"] > 0:
+            msg.fail(
+                "Train/dev mismatch: the dev data contains instances "
+                "without mutually-exclusive classes while the train data "
+                "contains only instances with mutually-exclusive classes."
+            )
+
+    if "textcat_multilabel" in factory_names:
+        msg.divider("Text Classification (Multilabel)")
+        labels = _get_labels_from_model(nlp, "textcat_multilabel")
+        msg.info(f"Text Classification: {len(labels)} label(s)")
+        msg.text(f"Labels: {_format_labels(labels)}", show=verbose)
+        labels_with_counts = _format_labels(
+            gold_train_data["cats"].most_common(), counts=True
+        )
+        msg.text(f"Labels in train data: {labels_with_counts}", show=verbose)
+        missing_labels = labels - set(gold_train_data["cats"].keys())
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
             )
-            msg.text(f"New: {labels_with_counts}", show=verbose)
-        if existing_labels:
-            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
         if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
             msg.fail(
                 f"The train and dev labels are not the same. "
@@ -299,11 +331,6 @@ def debug_data(
                 f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
             )
         if gold_train_data["n_cats_multilabel"] > 0:
-            msg.info(
-                "The train data contains instances without "
-                "mutually-exclusive classes. Use '--textcat-multilabel' "
-                "when training."
-            )
             if gold_dev_data["n_cats_multilabel"] == 0:
                 msg.warn(
                     "Potential train/dev mismatch: the train data contains "
@@ -311,9 +338,10 @@ def debug_data(
                     "dev data does not."
                 )
         else:
-            msg.info(
+            msg.warn(
                 "The train data contains only instances with "
-                "mutually-exclusive classes."
+                "mutually-exclusive classes. You can potentially use the "
+                "component 'textcat' instead of 'textcat_multilabel'."
             )
             if gold_dev_data["n_cats_multilabel"] > 0:
                 msg.fail(
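For reference (not part of the diff): whether an example counts as mutually exclusive is decided further down in _compile_gold by checking that exactly one gold category has the value 1.0, and the resulting n_cats_multilabel counter is what drives the warnings in the two textcat sections above. A tiny illustration with made-up category names:

def is_mutually_exclusive(cats: dict) -> bool:
    # Mirrors the check in _compile_gold: exactly one gold category scored 1.0
    return list(cats.values()).count(1.0) == 1

print(is_mutually_exclusive({"POSITIVE": 1.0, "NEGATIVE": 0.0}))  # True  -> fine for 'textcat'
print(is_mutually_exclusive({"SPORTS": 1.0, "POLITICS": 1.0}))    # False -> needs 'textcat_multilabel'
print(is_mutually_exclusive({"SPORTS": 0.0, "POLITICS": 0.0}))    # False -> needs 'textcat_multilabel'

Every example that fails this check increments n_cats_multilabel, which is what the exclusive-classes section warns about and what the train/dev mismatch checks compare.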
@@ -325,13 +353,37 @@ def debug_data(
     if "tagger" in factory_names:
         msg.divider("Part-of-speech Tagging")
         labels = [label for label in gold_train_data["tags"]]
-        # TODO: does this need to be updated?
-        msg.info(f"{len(labels)} label(s) in data")
+        model_labels = _get_labels_from_model(nlp, "tagger")
+        msg.info(f"{len(labels)} label(s) in train data")
+        missing_labels = model_labels - set(labels)
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
         labels_with_counts = _format_labels(
             gold_train_data["tags"].most_common(), counts=True
         )
         msg.text(labels_with_counts, show=verbose)

+    if "morphologizer" in factory_names:
+        msg.divider("Morphologizer (POS+Morph)")
+        labels = [label for label in gold_train_data["morphs"]]
+        model_labels = _get_labels_from_model(nlp, "morphologizer")
+        msg.info(f"{len(labels)} label(s) in train data")
+        missing_labels = model_labels - set(labels)
+        if missing_labels:
+            msg.warn(
+                "Some model labels are not present in the train data. The "
+                "model performance may be degraded for these labels after "
+                f"training: {_format_labels(missing_labels)}."
+            )
+        labels_with_counts = _format_labels(
+            gold_train_data["morphs"].most_common(), counts=True
+        )
+        msg.text(labels_with_counts, show=verbose)
+
     if "parser" in factory_names:
         has_low_data_warning = False
         msg.divider("Dependency Parsing")
@@ -491,6 +543,7 @@ def _compile_gold(
         "ner": Counter(),
         "cats": Counter(),
         "tags": Counter(),
+        "morphs": Counter(),
         "deps": Counter(),
         "words": Counter(),
         "roots": Counter(),
@@ -544,13 +597,36 @@ def _compile_gold(
                     data["ner"][combined_label] += 1
                 elif label == "-":
                     data["ner"]["-"] += 1
-        if "textcat" in factory_names:
+        if "textcat" in factory_names or "textcat_multilabel" in factory_names:
             data["cats"].update(gold.cats)
             if list(gold.cats.values()).count(1.0) != 1:
                 data["n_cats_multilabel"] += 1
         if "tagger" in factory_names:
             tags = eg.get_aligned("TAG", as_string=True)
             data["tags"].update([x for x in tags if x is not None])
+        if "morphologizer" in factory_names:
+            pos_tags = eg.get_aligned("POS", as_string=True)
+            morphs = eg.get_aligned("MORPH", as_string=True)
+            for pos, morph in zip(pos_tags, morphs):
+                # POS may align (same value for multiple tokens) when morph
+                # doesn't, so if either is misaligned (None), treat the
+                # annotation as missing so that truths doesn't end up with an
+                # unknown morph+POS combination
+                if pos is None or morph is None:
+                    pass
+                # If both are unset, the annotation is missing (empty morph
+                # converted from int is "_" rather than "")
+                elif pos == "" and morph == "":
+                    pass
+                # Otherwise, generate the combined label
+                else:
+                    label_dict = Morphology.feats_to_dict(morph)
+                    if pos:
+                        label_dict[Morphologizer.POS_FEAT] = pos
+                    label = eg.reference.vocab.strings[
+                        eg.reference.vocab.morphology.add(label_dict)
+                    ]
+                    data["morphs"].update([label])
         if "parser" in factory_names:
             aligned_heads, aligned_deps = eg.get_aligned_parse(projectivize=make_proj)
             data["deps"].update([x for x in aligned_deps if x is not None])
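A standalone sketch of what the combined POS+morph label built in the new morphologizer branch looks like. It reuses the helpers the diff imports at the top (Morphology.feats_to_dict, Morphologizer.POS_FEAT); the feature string and POS value below are assumed for illustration:

from spacy.lang.en import English
from spacy.morphology import Morphology
from spacy.pipeline import Morphologizer

nlp = English()
pos, morph = "NOUN", "Case=Nom|Number=Sing"   # assumed aligned gold values
label_dict = Morphology.feats_to_dict(morph)  # {"Case": "Nom", "Number": "Sing"}
if pos:
    label_dict[Morphologizer.POS_FEAT] = pos  # POS_FEAT is the "POS" field
key = nlp.vocab.morphology.add(label_dict)    # interned via the vocab, as in the diff
print(nlp.vocab.strings[key])                 # e.g. "Case=Nom|Number=Sing|POS=NOUN"

debug data then simply counts this string into data["morphs"], so the Morphologizer section can report label frequencies the same way the tagger section does.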
@@ -584,8 +660,8 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
     return count


-def _get_labels_from_model(nlp: Language, pipe_name: str) -> Sequence[str]:
+def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
     if pipe_name not in nlp.pipe_names:
         return set()
     pipe = nlp.get_pipe(pipe_name)
-    return pipe.labels
+    return set(pipe.labels)