mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
f37863093a
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉 See here: https://github.com/explosion/srsly Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place. At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel. srsly currently includes forks of the following packages: ujson msgpack msgpack-numpy cloudpickle * WIP: replace json/ujson with srsly * Replace ujson in examples Use regular json instead of srsly to make code easier to read and follow * Update requirements * Fix imports * Fix typos * Replace msgpack with srsly * Fix warning
400 lines
14 KiB
Python
400 lines
14 KiB
Python
# coding: utf8
|
||
from __future__ import unicode_literals, print_function
|
||
|
||
from pathlib import Path
|
||
from collections import Counter
|
||
import plac
|
||
import sys
|
||
import srsly
|
||
from wasabi import Printer, MESSAGES
|
||
|
||
from ..gold import GoldCorpus, read_json_object
|
||
from ..util import load_model, get_lang_class
|
||
|
||
# from .schemas import get_schema, validate_json
|
||
from ._messages import Messages
|
||
|
||
|
||
# Minimum number of expected occurrences of a label in the data to train a
# new label; new NER labels with a count at or below this trigger a warning
NEW_LABEL_THRESHOLD = 50
# Minimum number of expected examples to train a blank model; fewer training
# docs than this is reported as a failure
BLANK_MODEL_MIN_THRESHOLD = 100
# Recommended number of examples to train a blank model; fewer training docs
# than this (but at least the minimum above) is reported as a warning
BLANK_MODEL_THRESHOLD = 2000
|
||
|
||
|
||
@plac.annotations(
    lang=("model language", "positional", None, str),
    train_path=("location of JSON-formatted training data", "positional", None, Path),
    dev_path=("location of JSON-formatted development data", "positional", None, Path),
    base_model=("name of model to update (optional)", "option", "b", str),
    pipeline=(
        "Comma-separated names of pipeline components to train",
        "option",
        "p",
        str,
    ),
    ignore_warnings=("Ignore warnings, only show stats and errors", "flag", "IW", bool),
    ignore_validation=(
        "Don't exit if JSON format validation fails",
        "flag",
        "IV",
        bool,
    ),
    verbose=("Print additional information and explanations", "flag", "V", bool),
    no_format=("Don't pretty-print the results", "flag", "NF", bool),
)
def debug_data(
    lang,
    train_path,
    dev_path,
    base_model=None,
    pipeline="tagger,parser,ner",
    ignore_warnings=False,
    ignore_validation=False,
    verbose=False,
    no_format=False,
):
    """Analyze and validate training and development data.

    Loads both data files, builds a gold corpus from them and prints
    statistics, warnings and errors for the selected pipeline components.
    Exits with status 1 if any check was reported as failed.

    lang (unicode): Model language, used to create a blank model when no
        base model is given.
    train_path (Path): Location of the .json/.jsonl training data.
    dev_path (Path): Location of the .json/.jsonl development data.
    base_model (unicode): Name of an existing model to update (optional).
    pipeline (unicode): Comma-separated names of the components to train.
    ignore_warnings (bool): Only show stats and errors, no warnings.
    ignore_validation (bool): Don't exit if the format validation fails.
    verbose (bool): Print additional information and explanations.
    no_format (bool): Don't pretty-print the results.
    """
    msg = Printer(pretty=not no_format, ignore_warnings=ignore_warnings)

    # Make sure all files and paths exists if they are needed
    if not train_path.exists():
        msg.fail(Messages.M050, train_path, exits=1)
    if not dev_path.exists():
        msg.fail(Messages.M051, dev_path, exits=1)

    # Initialize the model and pipeline
    pipeline = [p.strip() for p in pipeline.split(",")]
    if base_model:
        nlp = load_model(base_model)
    else:
        lang_cls = get_lang_class(lang)
        nlp = lang_cls()

    msg.divider("Data format validation")
    # Load the data in one – might take a while but okay in this case
    with msg.loading("Loading {}...".format(train_path.parts[-1])):
        train_data = _load_file(train_path, msg)
    with msg.loading("Loading {}...".format(dev_path.parts[-1])):
        dev_data = _load_file(dev_path, msg)

    # Validate data format using the JSON schema
    # TODO: update once the new format is ready
    # schema = get_schema("training")
    train_data_errors = []  # TODO: validate_json(train_data, schema)
    dev_data_errors = []  # TODO: validate_json(dev_data, schema)
    if not train_data_errors:
        msg.good("Training data JSON format is valid")
    if not dev_data_errors:
        msg.good("Development data JSON format is valid")
    for error in train_data_errors:
        msg.fail("Training data: {}".format(error))
    for error in dev_data_errors:
        msg.fail("Development data: {}".format(error))
    if (train_data_errors or dev_data_errors) and not ignore_validation:
        sys.exit(1)

    # Create the gold corpus to be able to better analyze data
    with msg.loading("Analyzing corpus..."):
        train_data = read_json_object(train_data)
        dev_data = read_json_object(dev_data)
        corpus = GoldCorpus(train_data, dev_data)
        train_docs = list(corpus.train_docs(nlp))
        dev_docs = list(corpus.dev_docs(nlp))
    msg.good("Corpus is loadable")

    # Create all gold data here to avoid iterating over the train_docs constantly
    gold_data = _compile_gold(train_docs, pipeline)
    train_texts = gold_data["texts"]
    dev_texts = set(doc.text for doc, gold in dev_docs)

    msg.divider("Training stats")
    msg.text("Training pipeline: {}".format(", ".join(pipeline)))
    for pipe in [p for p in pipeline if p not in nlp.factories]:
        msg.fail("Pipeline component '{}' not available in factories".format(pipe))
    if base_model:
        msg.text("Starting with base model '{}'".format(base_model))
    else:
        msg.text("Starting with blank model '{}'".format(lang))
    msg.text("{} training docs".format(len(train_docs)))
    msg.text("{} evaluation docs".format(len(dev_docs)))

    overlap = len(train_texts.intersection(dev_texts))
    if overlap:
        msg.warn("{} training examples also in evaluation data".format(overlap))
    else:
        msg.good("No overlap between training and evaluation data")
    if not base_model and len(train_docs) < BLANK_MODEL_THRESHOLD:
        text = "Low number of examples to train from a blank model ({})".format(
            len(train_docs)
        )
        # Below the hard minimum it's an error, otherwise only a warning
        if len(train_docs) < BLANK_MODEL_MIN_THRESHOLD:
            msg.fail(text)
        else:
            msg.warn(text)
        msg.text(
            "It's recommended to use at least {} examples (minimum {})".format(
                BLANK_MODEL_THRESHOLD, BLANK_MODEL_MIN_THRESHOLD
            ),
            show=verbose,
        )

    msg.divider("Vocab & Vectors")
    n_words = gold_data["n_words"]
    msg.info(
        "{} total {} in the data ({} unique)".format(
            n_words, "word" if n_words == 1 else "words", len(gold_data["words"])
        )
    )
    most_common_words = gold_data["words"].most_common(10)
    msg.text(
        "10 most common words: {}".format(
            _format_labels(most_common_words, counts=True)
        ),
        show=verbose,
    )
    if len(nlp.vocab.vectors):
        msg.info(
            "{} vectors ({} unique keys, {} dimensions)".format(
                len(nlp.vocab.vectors),
                nlp.vocab.vectors.n_keys,
                nlp.vocab.vectors_length,
            )
        )
    else:
        msg.info("No word vectors present in the model")

    if "ner" in pipeline:
        # Get all unique NER labels present in the data
        label_counts = gold_data["ner"]
        labels = set(label for label in label_counts if label not in ("O", "-"))
        model_labels = _get_labels_from_model(nlp, "ner")
        new_labels = [l for l in labels if l not in model_labels]
        existing_labels = [l for l in labels if l in model_labels]
        has_low_data_warning = False
        has_no_neg_warning = False

        msg.divider("Named Entity Recognition")
        msg.info(
            "{} new {}, {} existing {}".format(
                len(new_labels),
                "label" if len(new_labels) == 1 else "labels",
                len(existing_labels),
                "label" if len(existing_labels) == 1 else "labels",
            )
        )
        missing_values = label_counts["-"]
        msg.text(
            "{} missing {} (tokens with '-' label)".format(
                missing_values, "value" if missing_values == 1 else "values"
            )
        )
        if new_labels:
            # Only list the labels that are actually new; existing labels
            # are reported separately below
            labels_with_counts = [
                (label, count)
                for label, count in label_counts.most_common()
                if label in new_labels
            ]
            labels_with_counts = _format_labels(labels_with_counts, counts=True)
            msg.text("New: {}".format(labels_with_counts), show=verbose)
        if existing_labels:
            msg.text(
                "Existing: {}".format(_format_labels(existing_labels)), show=verbose
            )

        for label in new_labels:
            if label_counts[label] <= NEW_LABEL_THRESHOLD:
                msg.warn(
                    "Low number of examples for new label '{}' ({})".format(
                        label, label_counts[label]
                    )
                )
                has_low_data_warning = True

                # Only labels with little data are also checked for
                # negative examples (docs that don't contain the label)
                with msg.loading("Analyzing label distribution..."):
                    neg_docs = _get_examples_without_label(train_docs, label)
                if neg_docs == 0:
                    msg.warn(
                        "No examples for texts WITHOUT new label '{}'".format(label)
                    )
                    has_no_neg_warning = True

        if not has_low_data_warning:
            msg.good("Good amount of examples for all labels")
        if not has_no_neg_warning:
            msg.good("Examples without occurrences available for all labels")

        if has_low_data_warning:
            msg.text(
                "To train a new entity type, your data should include at "
                "least {} instances of the new label".format(NEW_LABEL_THRESHOLD),
                show=verbose,
            )
        if has_no_neg_warning:
            msg.text(
                "Training data should always include examples of entities "
                "in context, as well as examples without a given entity "
                "type.",
                show=verbose,
            )

    if "textcat" in pipeline:
        msg.divider("Text Classification")
        # The category counts are compiled under the "cats" key of the
        # gold data (there is no "textcat" key)
        cat_counts = gold_data["cats"]
        labels = [label for label in cat_counts]
        model_labels = _get_labels_from_model(nlp, "textcat")
        new_labels = [l for l in labels if l not in model_labels]
        existing_labels = [l for l in labels if l in model_labels]
        msg.info(
            "Text Classification: {} new label(s), {} existing label(s)".format(
                len(new_labels), len(existing_labels)
            )
        )
        if new_labels:
            # Only list the labels that are actually new
            labels_with_counts = _format_labels(
                [
                    (label, count)
                    for label, count in cat_counts.most_common()
                    if label in new_labels
                ],
                counts=True,
            )
            msg.text("New: {}".format(labels_with_counts), show=verbose)
        if existing_labels:
            msg.text(
                "Existing: {}".format(_format_labels(existing_labels)), show=verbose
            )

    if "tagger" in pipeline:
        msg.divider("Part-of-speech Tagging")
        labels = [label for label in gold_data["tags"]]
        tag_map = nlp.Defaults.tag_map
        msg.info(
            "{} {} in data ({} {} in tag map)".format(
                len(labels),
                "label" if len(labels) == 1 else "labels",
                len(tag_map),
                "label" if len(tag_map) == 1 else "labels",
            )
        )
        labels_with_counts = _format_labels(
            gold_data["tags"].most_common(), counts=True
        )
        msg.text(labels_with_counts, show=verbose)
        non_tagmap = [l for l in labels if l not in tag_map]
        if not non_tagmap:
            msg.good("All labels present in tag map for language '{}'".format(nlp.lang))
        for label in non_tagmap:
            msg.fail(
                "Label '{}' not found in tag map for language '{}'".format(
                    label, nlp.lang
                )
            )

    if "parser" in pipeline:
        msg.divider("Dependency Parsing")
        labels = [label for label in gold_data["deps"]]
        msg.info(
            "{} {} in data".format(
                len(labels), "label" if len(labels) == 1 else "labels"
            )
        )
        labels_with_counts = _format_labels(
            gold_data["deps"].most_common(), counts=True
        )
        msg.text(labels_with_counts, show=verbose)

    msg.divider("Summary")
    # The printer keeps track of how many messages of each type were shown
    good_counts = msg.counts[MESSAGES.GOOD]
    warn_counts = msg.counts[MESSAGES.WARN]
    fail_counts = msg.counts[MESSAGES.FAIL]
    if good_counts:
        msg.good(
            "{} {} passed".format(
                good_counts, "check" if good_counts == 1 else "checks"
            )
        )
    if warn_counts:
        msg.warn(
            "{} {}".format(warn_counts, "warning" if warn_counts == 1 else "warnings")
        )
    if fail_counts:
        msg.fail("{} {}".format(fail_counts, "error" if fail_counts == 1 else "errors"))

    if fail_counts:
        sys.exit(1)
|
||
|
||
|
||
def _load_file(file_path, msg):
|
||
file_name = file_path.parts[-1]
|
||
if file_path.suffix == ".json":
|
||
data = srsly.read_json(file_path)
|
||
msg.good("Loaded {}".format(file_name))
|
||
return data
|
||
elif file_path.suffix == ".jsonl":
|
||
data = srsly.read_jsonl(file_path)
|
||
msg.good("Loaded {}".format(file_name))
|
||
return data
|
||
msg.fail(
|
||
"Can't load file extension {}".format(file_path.suffix),
|
||
"Expected .json or .jsonl",
|
||
exits=1,
|
||
)
|
||
|
||
|
||
def _compile_gold(train_docs, pipeline):
|
||
data = {
|
||
"ner": Counter(),
|
||
"cats": Counter(),
|
||
"tags": Counter(),
|
||
"deps": Counter(),
|
||
"words": Counter(),
|
||
"n_words": 0,
|
||
"texts": set(),
|
||
}
|
||
for doc, gold in train_docs:
|
||
data["words"].update(gold.words)
|
||
data["n_words"] += len(gold.words)
|
||
data["texts"].add(doc.text)
|
||
if "ner" in pipeline:
|
||
for label in gold.ner:
|
||
if label.startswith(("B-", "U-")):
|
||
combined_label = label.split("-")[1]
|
||
data["ner"][combined_label] += 1
|
||
elif label == "-":
|
||
data["ner"]["-"] += 1
|
||
if "textcat" in pipeline:
|
||
data["cats"].update(gold.cats)
|
||
if "tagger" in pipeline:
|
||
data["tags"].update(gold.tags)
|
||
if "parser" in pipeline:
|
||
data["deps"].update(gold.labels)
|
||
return data
|
||
|
||
|
||
def _format_labels(labels, counts=False):
|
||
if counts:
|
||
return ", ".join(["'{}' ({})".format(l, c) for l, c in labels])
|
||
return ", ".join(["'{}'".format(l) for l in labels])
|
||
|
||
|
||
def _get_ner_counts(data):
|
||
counter = Counter()
|
||
for doc, gold in data:
|
||
for label in gold.ner:
|
||
if label.startswith(("B-", "U-")):
|
||
combined_label = label.split("-")[1]
|
||
counter[combined_label] += 1
|
||
elif label == "-":
|
||
counter["-"] += 1
|
||
return counter
|
||
|
||
|
||
def _get_examples_without_label(data, label):
|
||
count = 0
|
||
for doc, gold in data:
|
||
labels = [label.split("-")[1] for label in gold.ner if label not in ("O", "-")]
|
||
if label not in labels:
|
||
count += 1
|
||
return count
|
||
|
||
|
||
def _get_labels_from_model(nlp, pipe_name):
|
||
if pipe_name not in nlp.pipe_names:
|
||
return set()
|
||
pipe = nlp.get_pipe(pipe_name)
|
||
return pipe.labels
|