Fix debug data [ci skip]

Ines Montani 2020-09-29 23:07:11 +02:00
parent a2aa1f6882
commit 9bb958fd0a
2 changed files with 14 additions and 43 deletions

spacy/cli/debug_data.py

@@ -8,12 +8,12 @@ import typer
 from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
 from ._util import import_code, debug_cli
-from ..training import Corpus, Example
+from ..training import Example
 from ..training.initialize import get_sourced_components
 from ..schemas import ConfigSchemaTraining
 from ..pipeline._parser_internals import nonproj
 from ..language import Language
-from ..util import registry
+from ..util import registry, resolve_dot_names
 from .. import util
@@ -37,8 +37,6 @@ BLANK_MODEL_THRESHOLD = 2000
 def debug_data_cli(
     # fmt: off
     ctx: typer.Context,  # This is only used to read additional arguments
-    train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
-    dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
     config_path: Path = Arg(..., help="Path to config file", exists=True),
     code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
@@ -62,8 +60,6 @@ def debug_data_cli(
     overrides = parse_config_overrides(ctx.args)
     import_code(code_path)
     debug_data(
-        train_path,
-        dev_path,
         config_path,
         config_overrides=overrides,
         ignore_warnings=ignore_warnings,
@@ -74,8 +70,6 @@ def debug_data_cli(
 def debug_data(
-    train_path: Path,
-    dev_path: Path,
     config_path: Path,
     *,
     config_overrides: Dict[str, Any] = {},
@@ -88,18 +82,11 @@ def debug_data(
         no_print=silent, pretty=not no_format, ignore_warnings=ignore_warnings
     )
     # Make sure all files and paths exists if they are needed
-    if not train_path.exists():
-        msg.fail("Training data not found", train_path, exits=1)
-    if not dev_path.exists():
-        msg.fail("Development data not found", dev_path, exits=1)
-    if not config_path.exists():
-        msg.fail("Config file not found", config_path, exists=1)
     with show_validation_error(config_path):
         cfg = util.load_config(config_path, overrides=config_overrides)
         nlp = util.load_model_from_config(cfg)
-        T = registry.resolve(
-            nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
-        )
+        config = nlp.config.interpolate()
+        T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
     # Use original config here, not resolved version
     sourced_components = get_sourced_components(cfg)
     frozen_components = T["frozen_components"]
@@ -109,25 +96,15 @@ def debug_data(
     msg.divider("Data file validation")
     # Create the gold corpus to be able to better analyze data
-    loading_train_error_message = ""
-    loading_dev_error_message = ""
-    with msg.loading("Loading corpus..."):
-        try:
-            train_dataset = list(Corpus(train_path)(nlp))
-        except ValueError as e:
-            loading_train_error_message = f"Training data cannot be loaded: {e}"
-        try:
-            dev_dataset = list(Corpus(dev_path)(nlp))
-        except ValueError as e:
-            loading_dev_error_message = f"Development data cannot be loaded: {e}"
-    if loading_train_error_message or loading_dev_error_message:
-        if loading_train_error_message:
-            msg.fail(loading_train_error_message)
-        if loading_dev_error_message:
-            msg.fail(loading_dev_error_message)
-        sys.exit(1)
+    dot_names = [T["train_corpus"], T["dev_corpus"]]
+    train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
+    train_dataset = list(train_corpus(nlp))
+    dev_dataset = list(dev_corpus(nlp))
     msg.good("Corpus is loadable")
+    nlp.initialize(lambda: train_dataset)
+    msg.good("Pipeline can be initialized with data")
     # Create all gold data here to avoid iterating over the train_dataset constantly
     gold_train_data = _compile_gold(train_dataset, factory_names, nlp, make_proj=True)
     gold_train_unpreprocessed_data = _compile_gold(
@@ -348,17 +325,11 @@ def debug_data(
         msg.divider("Part-of-speech Tagging")
         labels = [label for label in gold_train_data["tags"]]
         # TODO: does this need to be updated?
-        tag_map = nlp.vocab.morphology.tag_map
-        msg.info(f"{len(labels)} label(s) in data ({len(tag_map)} label(s) in tag map)")
+        msg.info(f"{len(labels)} label(s) in data")
         labels_with_counts = _format_labels(
             gold_train_data["tags"].most_common(), counts=True
         )
         msg.text(labels_with_counts, show=verbose)
-        non_tagmap = [l for l in labels if l not in tag_map]
-        if not non_tagmap:
-            msg.good(f"All labels present in tag map for language '{nlp.lang}'")
-        for label in non_tagmap:
-            msg.fail(f"Label '{label}' not found in tag map for language '{nlp.lang}'")
     if "parser" in factory_names:
         has_low_data_warning = False
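
Taken together, the Python changes swap the explicit `train_path`/`dev_path` arguments for corpora resolved out of the training config, the same mechanism `spacy train` uses. A minimal sketch of the new loading flow, assuming spaCy v3 and a standard training config (the `config.cfg` path is a placeholder):

```python
from spacy import util
from spacy.schemas import ConfigSchemaTraining
from spacy.util import registry, resolve_dot_names

# Load the config and build a pipeline from it
cfg = util.load_config("config.cfg")  # placeholder path
nlp = util.load_model_from_config(cfg)
config = nlp.config.interpolate()

# Validate/resolve the [training] block against its schema
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)

# "train_corpus"/"dev_corpus" hold dot names like "corpora.train"
# that point at reader blocks elsewhere in the config
train_corpus, dev_corpus = resolve_dot_names(
    config, [T["train_corpus"], T["dev_corpus"]]
)

# The resolved readers are callables that yield Example objects
train_dataset = list(train_corpus(nlp))
dev_dataset = list(dev_corpus(nlp))
nlp.initialize(lambda: train_dataset)
```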

website/docs/api/cli.md

@@ -436,6 +436,7 @@ $ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbose]
 ```
 =========================== Data format validation ===========================
 ✔ Corpus is loadable
+✔ Pipeline can be initialized with data
 =============================== Training stats ===============================
 Training pipeline: tagger, parser, ner
@@ -465,7 +466,7 @@ New: 'ORG' (23860), 'PERSON' (21395), 'GPE' (21193), 'DATE' (18080), 'CARDINAL'
 ✔ No entities consisting of or starting/ending with whitespace
 =========================== Part-of-speech Tagging ===========================
-49 labels in data (57 labels in tag map)
+49 labels in data
 'NN' (266331), 'IN' (227365), 'DT' (185600), 'NNP' (164404), 'JJ' (119830),
 'NNS' (110957), '.' (101482), ',' (92476), 'RB' (90090), 'PRP' (90081), 'VB'
 (74538), 'VBD' (68199), 'CC' (62862), 'VBZ' (50712), 'VBP' (43420), 'VBN'
@@ -476,7 +477,6 @@ New: 'ORG' (23860), 'PERSON' (21395), 'GPE' (21193), 'DATE' (18080), 'CARDINAL'
 '-RRB-' (2825), '-LRB-' (2788), 'PDT' (2078), 'XX' (1316), 'RBS' (1142), 'FW'
 (794), 'NFP' (557), 'SYM' (440), 'WP$' (294), 'LS' (293), 'ADD' (191), 'AFX'
 (24)
-✔ All labels present in tag map for language 'en'
 ============================= Dependency Parsing =============================
 Found 111703 sentences with an average length of 18.6 words.
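
For reference, the dot names resolved above come from the config itself: `[training]` holds references such as `"corpora.train"`, which point at reader blocks under `[corpora]` that load the `.spacy` files. An illustrative excerpt of the relevant blocks (paths are placeholders; they can also be set at the command line via overrides such as `--paths.train ./train.spacy`):

```ini
[paths]
train = "corpus/train.spacy"
dev = "corpus/dev.spacy"

[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${paths.train}

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}

[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"
```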