2020-07-11 14:03:53 +03:00
|
|
|
from typing import List, Sequence, Dict, Any, Tuple, Optional
|
2018-11-30 22:16:14 +03:00
|
|
|
from pathlib import Path
|
|
|
|
from collections import Counter
|
|
|
|
import sys
|
💫 Replace ujson, msgpack and dill/pickle/cloudpickle with srsly (#3003)
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉
See here: https://github.com/explosion/srsly
Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel.
srsly currently includes forks of the following packages:
ujson
msgpack
msgpack-numpy
cloudpickle
* WIP: replace json/ujson with srsly
* Replace ujson in examples
Use regular json instead of srsly to make code easier to read and follow
* Update requirements
* Fix imports
* Fix typos
* Replace msgpack with srsly
* Fix warning
2018-12-03 03:28:22 +03:00
|
|
|
import srsly
|
2020-08-17 22:38:20 +03:00
|
|
|
from wasabi import Printer, MESSAGES, msg
|
2020-07-11 14:03:53 +03:00
|
|
|
import typer
|
2018-11-30 22:16:14 +03:00
|
|
|
|
2020-07-11 14:03:53 +03:00
|
|
|
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
2020-09-28 16:09:59 +03:00
|
|
|
from ._util import import_code, debug_cli
|
2020-09-30 00:07:11 +03:00
|
|
|
from ..training import Example
|
2020-09-28 16:09:59 +03:00
|
|
|
from ..training.initialize import get_sourced_components
|
|
|
|
from ..schemas import ConfigSchemaTraining
|
2020-07-31 00:30:54 +03:00
|
|
|
from ..pipeline._parser_internals import nonproj
|
2021-01-07 20:58:13 +03:00
|
|
|
from ..pipeline._parser_internals.nonproj import DELIMITER
|
2020-06-21 22:35:01 +03:00
|
|
|
from ..language import Language
|
2020-09-30 00:07:11 +03:00
|
|
|
from ..util import registry, resolve_dot_names
|
2020-07-11 14:03:53 +03:00
|
|
|
from .. import util
|
2018-11-30 22:16:14 +03:00
|
|
|
|
|
|
|
|
2019-08-16 11:52:46 +03:00
|
|
|
# NER: a new entity label should occur at least this many times in the data
# before we consider it trainable.
NEW_LABEL_THRESHOLD = 50
# Parser: each dependency label should occur at least this many times.
DEP_LABEL_THRESHOLD = 20
# Training a pipeline from scratch: fewer examples than the MIN threshold is
# reported as an error, fewer than the regular threshold as a warning.
BLANK_MODEL_MIN_THRESHOLD = 100
BLANK_MODEL_THRESHOLD = 2000
|
|
|
|
|
|
|
|
|
2020-07-12 14:53:41 +03:00
|
|
|
@debug_cli.command(
    "data", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
@app.command(
    "debug-data",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
    hidden=True,  # hide this from main CLI help but still allow it to work with warning
)
def debug_data_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
    verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
    no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
    # fmt: on
):
    """
    Analyze, debug and validate your training and development data. Outputs
    useful stats, and can help you find problems like invalid entity annotations,
    cyclic dependencies, low data labels and more.

    DOCS: https://nightly.spacy.io/api/cli#debug-data
    """
    # The command is registered twice: as `debug data` and as the hidden,
    # hyphenated legacy alias `debug-data`. Nudge users of the alias.
    invoked_via_legacy_alias = ctx.command.name == "debug-data"
    if invoked_via_legacy_alias:
        msg.warn(
            "The debug-data command is now available via the 'debug data' "
            "subcommand (without the hyphen). You can run python -m spacy debug "
            "--help for an overview of the other available debugging commands."
        )
    # Unknown CLI options are forwarded as config overrides.
    cli_overrides = parse_config_overrides(ctx.args)
    # Register user-defined functions before the config is resolved.
    import_code(code_path)
    debug_data(
        config_path,
        config_overrides=cli_overrides,
        ignore_warnings=ignore_warnings,
        verbose=verbose,
        no_format=no_format,
        silent=False,
    )
|
|
|
|
|
2018-11-30 22:16:14 +03:00
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def debug_data(
    config_path: Path,
    *,
    config_overrides: Optional[Dict[str, Any]] = None,
    ignore_warnings: bool = False,
    verbose: bool = False,
    no_format: bool = True,
    silent: bool = True,
):
    """Analyze the training and development data referenced by a config.

    Loads the pipeline from ``config_path`` and resolves and loads the
    train/dev corpora named in the ``[training]`` section. It then
    initializes the pipeline and prints statistics and sanity checks for
    each supported component type (ner, textcat, tagger, parser) found in
    the pipeline.

    config_path: Path to the training config.
    config_overrides: Config overrides applied when loading (default: none).
    ignore_warnings: Passed to the Printer; only show stats and errors.
    verbose: Print additional information and explanations.
    no_format: Don't pretty-print the results.
    silent: Don't print anything (Printer ``no_print``).

    Calls ``sys.exit(1)`` if at least one check failed.
    """
    msg = Printer(
        no_print=silent, pretty=not no_format, ignore_warnings=ignore_warnings
    )
    if config_overrides is None:
        config_overrides = {}
    # Make sure all files and paths exists if they are needed
    with show_validation_error(config_path):
        cfg = util.load_config(config_path, overrides=config_overrides)
        nlp = util.load_model_from_config(cfg)
        config = nlp.config.interpolate()
        T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
    # Use original config here, not resolved version
    sourced_components = get_sourced_components(cfg)
    frozen_components = T["frozen_components"]
    resume_components = [p for p in sourced_components if p not in frozen_components]
    pipeline = nlp.pipe_names
    factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
    msg.divider("Data file validation")

    # Create the gold corpus to be able to better analyze data
    dot_names = [T["train_corpus"], T["dev_corpus"]]
    train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
    train_dataset = list(train_corpus(nlp))
    dev_dataset = list(dev_corpus(nlp))
    msg.good("Corpus is loadable")

    nlp.initialize(lambda: train_dataset)
    msg.good("Pipeline can be initialized with data")

    # Create all gold data here to avoid iterating over the train_dataset constantly
    gold_train_data = _compile_gold(train_dataset, factory_names, nlp, make_proj=True)
    gold_train_unpreprocessed_data = _compile_gold(
        train_dataset, factory_names, nlp, make_proj=False
    )
    gold_dev_data = _compile_gold(dev_dataset, factory_names, nlp, make_proj=True)

    train_texts = gold_train_data["texts"]
    dev_texts = gold_dev_data["texts"]

    msg.divider("Training stats")
    msg.text(f"Language: {nlp.lang}")
    msg.text(f"Training pipeline: {', '.join(pipeline)}")
    if resume_components:
        msg.text(f"Components from other pipelines: {', '.join(resume_components)}")
    if frozen_components:
        msg.text(f"Frozen components: {', '.join(frozen_components)}")
    msg.text(f"{len(train_dataset)} training docs")
    msg.text(f"{len(dev_dataset)} evaluation docs")

    # Check the examples themselves: the compiled stats dict always has keys,
    # so testing its length could never detect an empty dev set.
    if not dev_dataset:
        msg.fail("No evaluation docs")
    overlap = len(train_texts.intersection(dev_texts))
    if overlap:
        msg.warn(f"{overlap} training examples also in evaluation data")
    else:
        msg.good("No overlap between training and evaluation data")
    # TODO: make this feedback more fine-grained and report on updated
    # components vs. blank components
    if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
        text = f"Low number of examples to train a new pipeline ({len(train_dataset)})"
        if len(train_dataset) < BLANK_MODEL_MIN_THRESHOLD:
            msg.fail(text)
        else:
            msg.warn(text)
        msg.text(
            f"It's recommended to use at least {BLANK_MODEL_THRESHOLD} examples "
            f"(minimum {BLANK_MODEL_MIN_THRESHOLD})",
            show=verbose,
        )

    msg.divider("Vocab & Vectors")
    n_words = gold_train_data["n_words"]
    msg.info(
        f"{n_words} total word(s) in the data ({len(gold_train_data['words'])} unique)"
    )
    if gold_train_data["n_misaligned_words"] > 0:
        n_misaligned = gold_train_data["n_misaligned_words"]
        msg.warn(f"{n_misaligned} misaligned tokens in the training data")
    if gold_dev_data["n_misaligned_words"] > 0:
        n_misaligned = gold_dev_data["n_misaligned_words"]
        msg.warn(f"{n_misaligned} misaligned tokens in the dev data")
    most_common_words = gold_train_data["words"].most_common(10)
    msg.text(
        f"10 most common words: {_format_labels(most_common_words, counts=True)}",
        show=verbose,
    )
    if len(nlp.vocab.vectors):
        msg.info(
            f"{len(nlp.vocab.vectors)} vectors ({nlp.vocab.vectors.n_keys} "
            f"unique keys, {nlp.vocab.vectors_length} dimensions)"
        )
        n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values())
        # Scale to a real percentage (the message prints a "%" sign) and
        # avoid dividing by zero on an empty training set.
        pct_missing = 100 * n_missing_vectors / n_words if n_words else 0.0
        msg.warn(
            "{} words in training data without vectors ({:0.2f}%)".format(
                n_missing_vectors, pct_missing
            ),
        )
        msg.text(
            "10 most common words without vectors: {}".format(
                _format_labels(
                    gold_train_data["words_missing_vectors"].most_common(10),
                    counts=True,
                )
            ),
            show=verbose,
        )
    else:
        msg.info("No word vectors present in the package")

    if "ner" in factory_names:
        # Get all unique NER labels present in the data
        labels = set(
            label for label in gold_train_data["ner"] if label not in ("O", "-", None)
        )
        label_counts = gold_train_data["ner"]
        model_labels = _get_labels_from_model(nlp, "ner")
        # Sorted so the report is deterministic across runs.
        new_labels = sorted(l for l in labels if l not in model_labels)
        existing_labels = sorted(l for l in labels if l in model_labels)
        has_low_data_warning = False
        has_no_neg_warning = False
        has_ws_ents_error = False
        has_punct_ents_warning = False

        msg.divider("Named Entity Recognition")
        msg.info(
            f"{len(new_labels)} new label(s), {len(existing_labels)} existing label(s)"
        )
        missing_values = label_counts["-"]
        msg.text(f"{missing_values} missing value(s) (tokens with '-' label)")
        for label in new_labels:
            if len(label) == 0:
                msg.fail("Empty label found in new labels")
        if new_labels:
            # Only list the new labels here; existing ones are listed below.
            new_label_set = set(new_labels)
            labels_with_counts = [
                (label, count)
                for label, count in label_counts.most_common()
                if label in new_label_set
            ]
            labels_with_counts = _format_labels(labels_with_counts, counts=True)
            msg.text(f"New: {labels_with_counts}", show=verbose)
        if existing_labels:
            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
        if gold_train_data["ws_ents"]:
            msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
            has_ws_ents_error = True

        if gold_train_data["punct_ents"]:
            msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
            has_punct_ents_warning = True

        for label in new_labels:
            if label_counts[label] <= NEW_LABEL_THRESHOLD:
                msg.warn(
                    f"Low number of examples for new label '{label}' ({label_counts[label]})"
                )
                has_low_data_warning = True

                with msg.loading("Analyzing label distribution..."):
                    neg_docs = _get_examples_without_label(train_dataset, label)
                if neg_docs == 0:
                    msg.warn(f"No examples for texts WITHOUT new label '{label}'")
                    has_no_neg_warning = True

        if not has_low_data_warning:
            msg.good("Good amount of examples for all labels")
        if not has_no_neg_warning:
            msg.good("Examples without occurrences available for all labels")
        if not has_ws_ents_error:
            msg.good("No entities consisting of or starting/ending with whitespace")
        if not has_punct_ents_warning:
            msg.good("No entities consisting of or starting/ending with punctuation")

        if has_low_data_warning:
            msg.text(
                "To train a new entity type, your data should include at "
                f"least {NEW_LABEL_THRESHOLD} instances of the new label",
                show=verbose,
            )
        if has_no_neg_warning:
            msg.text(
                "Training data should always include examples of entities "
                "in context, as well as examples without a given entity "
                "type.",
                show=verbose,
            )
        if has_ws_ents_error:
            msg.text(
                "As of spaCy v2.1.0, entity spans consisting of or starting/ending "
                "with whitespace characters are considered invalid."
            )

        if has_punct_ents_warning:
            msg.text(
                "Entity spans consisting of or starting/ending "
                "with punctuation can not be trained with a noise level > 0."
            )

    if "textcat" in factory_names:
        msg.divider("Text Classification")
        labels = [label for label in gold_train_data["cats"]]
        model_labels = _get_labels_from_model(nlp, "textcat")
        new_labels = [l for l in labels if l not in model_labels]
        existing_labels = [l for l in labels if l in model_labels]
        msg.info(
            f"Text Classification: {len(new_labels)} new label(s), "
            f"{len(existing_labels)} existing label(s)"
        )
        if new_labels:
            # Only list the new labels here; existing ones are listed below.
            new_label_set = set(new_labels)
            labels_with_counts = _format_labels(
                [
                    (label, count)
                    for label, count in gold_train_data["cats"].most_common()
                    if label in new_label_set
                ],
                counts=True,
            )
            msg.text(f"New: {labels_with_counts}", show=verbose)
        if existing_labels:
            msg.text(f"Existing: {_format_labels(existing_labels)}", show=verbose)
        if set(gold_train_data["cats"]) != set(gold_dev_data["cats"]):
            msg.fail(
                f"The train and dev labels are not the same. "
                f"Train labels: {_format_labels(gold_train_data['cats'])}. "
                f"Dev labels: {_format_labels(gold_dev_data['cats'])}."
            )
        if gold_train_data["n_cats_multilabel"] > 0:
            msg.info(
                "The train data contains instances without "
                "mutually-exclusive classes. Use '--textcat-multilabel' "
                "when training."
            )
            if gold_dev_data["n_cats_multilabel"] == 0:
                msg.warn(
                    "Potential train/dev mismatch: the train data contains "
                    "instances without mutually-exclusive classes while the "
                    "dev data does not."
                )
        else:
            msg.info(
                "The train data contains only instances with "
                "mutually-exclusive classes."
            )
            if gold_dev_data["n_cats_multilabel"] > 0:
                msg.fail(
                    "Train/dev mismatch: the dev data contains instances "
                    "without mutually-exclusive classes while the train data "
                    "contains only instances with mutually-exclusive classes."
                )

    if "tagger" in factory_names:
        msg.divider("Part-of-speech Tagging")
        labels = [label for label in gold_train_data["tags"]]
        # TODO: does this need to be updated?
        msg.info(f"{len(labels)} label(s) in data")
        labels_with_counts = _format_labels(
            gold_train_data["tags"].most_common(), counts=True
        )
        msg.text(labels_with_counts, show=verbose)

    if "parser" in factory_names:
        has_low_data_warning = False
        msg.divider("Dependency Parsing")

        # profile sentence length (guard against a corpus without sentences)
        n_sents = gold_train_data["n_sents"]
        avg_sent_len = gold_train_data["n_words"] / n_sents if n_sents else 0.0
        msg.info(
            f"Found {n_sents} sentence(s) with an average "
            f"length of {avg_sent_len:.1f} words."
        )

        # check for documents with multiple sentences
        n_texts = len(gold_train_data["texts"])
        sents_per_doc = n_sents / n_texts if n_texts else 0.0
        if sents_per_doc < 1.1:
            msg.warn(
                f"The training data contains {sents_per_doc:.2f} sentences per "
                f"document. When there are very few documents containing more "
                f"than one sentence, the parser will not learn how to segment "
                f"longer texts into sentences."
            )

        # profile labels
        labels_train = [label for label in gold_train_data["deps"]]
        labels_train_unpreprocessed = [
            label for label in gold_train_unpreprocessed_data["deps"]
        ]
        labels_dev = [label for label in gold_dev_data["deps"]]

        if gold_train_unpreprocessed_data["n_nonproj"] > 0:
            n_nonproj = gold_train_unpreprocessed_data["n_nonproj"]
            msg.info(f"Found {n_nonproj} nonprojective train sentence(s)")
        if gold_dev_data["n_nonproj"] > 0:
            n_nonproj = gold_dev_data["n_nonproj"]
            msg.info(f"Found {n_nonproj} nonprojective dev sentence(s)")
        msg.info(f"{len(labels_train_unpreprocessed)} label(s) in train data")
        msg.info(f"{len(labels_train)} label(s) in projectivized train data")
        labels_with_counts = _format_labels(
            gold_train_unpreprocessed_data["deps"].most_common(), counts=True
        )
        msg.text(labels_with_counts, show=verbose)

        # rare labels in train
        for label in gold_train_unpreprocessed_data["deps"]:
            if gold_train_unpreprocessed_data["deps"][label] <= DEP_LABEL_THRESHOLD:
                msg.warn(
                    f"Low number of examples for label '{label}' "
                    f"({gold_train_unpreprocessed_data['deps'][label]})"
                )
                has_low_data_warning = True

        # rare labels in projectivized train (labels containing the
        # projectivization DELIMITER were introduced by pseudo-projective lifting)
        rare_projectivized_labels = []
        for label in gold_train_data["deps"]:
            if (
                gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD
                and DELIMITER in label
            ):
                rare_projectivized_labels.append(
                    f"{label}: {gold_train_data['deps'][label]}"
                )

        if len(rare_projectivized_labels) > 0:
            msg.warn(
                f"Low number of examples for {len(rare_projectivized_labels)} "
                "label(s) in the projectivized dependency trees used for "
                "training. You may want to projectivize labels such as punct "
                "before training in order to improve parser performance."
            )
            msg.warn(
                "Projectivized labels with low numbers of examples: ",
                ", ".join(rare_projectivized_labels),
                show=verbose,
            )
            has_low_data_warning = True

        # labels only in train (sorted for deterministic output)
        if set(labels_train) - set(labels_dev):
            msg.warn(
                "The following labels were found only in the train data:",
                ", ".join(sorted(set(labels_train) - set(labels_dev))),
                show=verbose,
            )

        # labels only in dev (sorted for deterministic output)
        if set(labels_dev) - set(labels_train):
            msg.warn(
                "The following labels were found only in the dev data:",
                ", ".join(sorted(set(labels_dev) - set(labels_train))),
                show=verbose,
            )

        if has_low_data_warning:
            msg.text(
                "To train a parser, your data should include at "
                f"least {DEP_LABEL_THRESHOLD} instances of each label.",
                show=verbose,
            )

        # multiple root labels
        if len(gold_train_unpreprocessed_data["roots"]) > 1:
            # str() so a missing (None) root label cannot break the join.
            root_labels = ", ".join(
                str(root) for root in gold_train_unpreprocessed_data["roots"]
            )
            msg.warn(
                f"Multiple root labels "
                f"({root_labels}) "
                f"found in training data. spaCy's parser uses a single root "
                f"label ROOT so this distinction will not be available."
            )

        # these should not happen, but just in case
        if gold_train_data["n_nonproj"] > 0:
            msg.fail(
                f"Found {gold_train_data['n_nonproj']} nonprojective "
                f"projectivized train sentence(s)"
            )
        if gold_train_data["n_cycles"] > 0:
            msg.fail(
                f"Found {gold_train_data['n_cycles']} projectivized train sentence(s) with cycles"
            )

    msg.divider("Summary")
    good_counts = msg.counts[MESSAGES.GOOD]
    warn_counts = msg.counts[MESSAGES.WARN]
    fail_counts = msg.counts[MESSAGES.FAIL]
    if good_counts:
        msg.good(f"{good_counts} {'check' if good_counts == 1 else 'checks'} passed")
    if warn_counts:
        msg.warn(f"{warn_counts} {'warning' if warn_counts == 1 else 'warnings'}")
    if fail_counts:
        msg.fail(f"{fail_counts} {'error' if fail_counts == 1 else 'errors'}")
        sys.exit(1)
|
|
|
|
|
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def _load_file(file_path: Path, msg: Printer) -> Any:
    """Load a ``.json`` or ``.jsonl`` file with srsly, reporting via ``msg``.

    file_path: Path of the file to load; the suffix selects the reader.
    msg: Printer used for the loading spinner and status messages.
    RETURNS: The result of ``srsly.read_json`` for ``.json`` files, or of
        ``srsly.read_jsonl`` for ``.jsonl`` files (presumably a lazy iterable
        of records — TODO confirm against srsly docs).

    For any other suffix, ``msg.fail(..., exits=1)`` is called, which is
    expected to terminate the process.
    """
    file_name = file_path.name
    if file_path.suffix == ".json":
        with msg.loading(f"Loading {file_name}..."):
            data = srsly.read_json(file_path)
        msg.good(f"Loaded {file_name}")
        return data
    elif file_path.suffix == ".jsonl":
        with msg.loading(f"Loading {file_name}..."):
            data = srsly.read_jsonl(file_path)
        msg.good(f"Loaded {file_name}")
        return data
    msg.fail(
        f"Can't load file extension {file_path.suffix}",
        "Expected .json or .jsonl",
        exits=1,
    )
|
|
|
|
|
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def _compile_gold(
    examples: Sequence[Example],
    factory_names: List[str],
    nlp: Language,
    make_proj: bool,
) -> Dict[str, Any]:
    """Collect the corpus statistics used by the ``debug data`` report.

    Walks every example once and tallies word counts, predicted/reference
    token misalignments, unique texts and (if the pipeline has vectors)
    words without vectors. Depending on which factories appear in
    ``factory_names``, it also tallies:
    - NER: label counts, plus entities on whitespace or punctuation tokens.
    - textcat: category counts and non-single-positive examples.
    - tagger: tag counts.
    - parser: dependency labels, root labels, sentence counts,
      non-projective trees and cycles.

    examples: The examples to analyze.
    factory_names: Factory names of the pipeline's components.
    nlp: The pipeline; used to look up its vectors table.
    make_proj: Whether to projectivize the aligned parse before counting.
    RETURNS: A dict of Counters, integer counts, and the set of unique texts.
    """
    data = {
        "ner": Counter(),
        "cats": Counter(),
        "tags": Counter(),
        "deps": Counter(),
        "words": Counter(),
        "roots": Counter(),
        "ws_ents": 0,
        "punct_ents": 0,
        "n_words": 0,
        "n_misaligned_words": 0,
        "words_missing_vectors": Counter(),
        "n_sents": 0,
        "n_nonproj": 0,
        "n_cycles": 0,
        "n_cats_multilabel": 0,
        "texts": set(),
    }
    # Entity boundaries on these tokens may be replaced by whitespace when
    # training with noise, so they are reported as a warning.
    punct_texts = {".", "'", "!", "?", ","}
    # The vectors table does not change while we iterate, so check it once.
    has_vectors = bool(len(nlp.vocab.vectors))
    for eg in examples:
        gold = eg.reference
        doc = eg.predicted
        valid_words = [x.text for x in gold]
        data["words"].update(valid_words)
        data["n_words"] += len(valid_words)
        align = eg.alignment
        for token in doc:
            if token.orth_.isspace():
                continue
            # A predicted token that doesn't map to exactly one gold token.
            if align.x2y.lengths[token.i] != 1:
                data["n_misaligned_words"] += 1
        data["texts"].add(doc.text)
        if has_vectors:
            for word in (t.text for t in doc):
                if nlp.vocab.strings[word] not in nlp.vocab.vectors:
                    data["words_missing_vectors"].update([word])
        if "ner" in factory_names:
            for i, label in enumerate(eg.get_aligned_ner()):
                if label is None:
                    continue
                is_boundary = label.startswith(("B-", "U-", "L-"))
                if is_boundary and doc[i].is_space:
                    # "Illegal" whitespace entity
                    data["ws_ents"] += 1
                if is_boundary and doc[i].text in punct_texts:
                    # punctuation entity: could be replaced by whitespace when training with noise,
                    # so add a warning to alert the user to this unexpected side effect.
                    data["punct_ents"] += 1
                if label.startswith(("B-", "U-")):
                    # Strip only the BILUO prefix so labels that themselves
                    # contain "-" are kept intact.
                    combined_label = label.split("-", 1)[1]
                    data["ner"][combined_label] += 1
                elif label == "-":
                    data["ner"]["-"] += 1
        if "textcat" in factory_names:
            data["cats"].update(gold.cats)
            # Anything other than exactly one positive category counts as multilabel.
            if list(gold.cats.values()).count(1.0) != 1:
                data["n_cats_multilabel"] += 1
        if "tagger" in factory_names:
            tags = eg.get_aligned("TAG", as_string=True)
            data["tags"].update([x for x in tags if x is not None])
        if "parser" in factory_names:
            aligned_heads, aligned_deps = eg.get_aligned_parse(projectivize=make_proj)
            data["deps"].update([x for x in aligned_deps if x is not None])
            for i, (dep, head) in enumerate(zip(aligned_deps, aligned_heads)):
                # A token that is its own head is a sentence root.
                if head == i:
                    # Skip missing labels so the root report can join names.
                    if dep is not None:
                        data["roots"].update([dep])
                    data["n_sents"] += 1
            if nonproj.is_nonproj_tree(aligned_heads):
                data["n_nonproj"] += 1
            if nonproj.contains_cycle(aligned_heads):
                data["n_cycles"] += 1
    return data
|
|
|
|
|
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def _format_labels(labels: List[Tuple[str, int]], counts: bool = False) -> str:
|
2018-11-30 22:16:14 +03:00
|
|
|
if counts:
|
2019-12-22 03:53:56 +03:00
|
|
|
return ", ".join([f"'{l}' ({c})" for l, c in labels])
|
|
|
|
return ", ".join([f"'{l}'" for l in labels])
|
2018-11-30 22:16:14 +03:00
|
|
|
|
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
    """Count the examples in which ``label`` is not used as an entity label.

    data: The examples to scan.
    label: The entity label to look for, without a BILUO prefix.
    RETURNS: How many examples have no aligned NER tag carrying ``label``.
    """
    count = 0
    for eg in data:
        # Strip the BILUO prefix, splitting only on the first hyphen so that
        # labels which themselves contain "-" are compared correctly.
        present = {
            ner_tag.split("-", 1)[1]
            for ner_tag in eg.get_aligned_ner()
            if ner_tag not in ("O", "-", None)
        }
        if label not in present:
            count += 1
    return count
|
|
|
|
|
|
|
|
|
2020-06-21 22:35:01 +03:00
|
|
|
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Sequence[str]:
    """Return the ``labels`` of the component named ``pipe_name``.

    If the pipeline has no component with that name, an empty set is
    returned instead, so callers can still use ``in`` checks on the result.
    """
    if pipe_name in nlp.pipe_names:
        return nlp.get_pipe(pipe_name).labels
    return set()
|