Add spacy-span-analyzer to debug data (#10668)

* Rename to spans_key for consistency

* Implement spans length in debug data

* Implement how span bounds and spans are obtained

In this commit, I implemented how span boundaries (the tokens) around a
given span and spans are obtained. I've put them in the compile_gold()
function so that it's accessible later on. I will do the actual
computation of the span and boundary distinctiveness in the main
function above.

* Compute for p_spans and p_bounds

* Add computation for SD and BD

* Fix mypy issues

* Add weighted average computation

* Fix compile_gold conditional logic

* Add test for frequency distribution computation

* Add tests for kl-divergence computation

* Fix weighted average computation

* Make tables more compact by rounding them

* Add more descriptive checks for spans

* Modularize span computation methods

In this commit, I added the _get_span_characteristics and
_print_span_characteristics functions so that they can be reusable
anywhere.

* Remove unnecessary arguments and make fxs more compact

* Update a few parameter arguments

* Add tests for print_span and get_span methods

* Update API to talk about span characteristics in brief

* Add better reporting of spans_length

* Add test for span length reporting

* Update formatting of span length report

Removed '' to indicate that it's not a string, then
sort the n-grams by their length, not by their frequency.

* Apply suggestions from code review

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Show all frequency distribution when -V

In this commit, I displayed the full frequency distribution of the
span lengths when --verbose is passed. To make things simpler, I
rewrote some of the formatter functions so that I can call them
whenever.

Another notable change is that instead of showing percentages as
Integers, I showed them as floats (max 2-decimal places). I did this
because it looks weird when it displays (0%).

* Update logic on how total is computed

The way the 90% thresholding is computed now is that we keep
adding the percentages until we reach >= 90%. I also updated the wording
and used the term "At least" to denote that >= 90% of your spans have
these distributions.

* Fix display when showing the threshold percentage

* Apply suggestions from code review

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Add better phrasing for span information

* Update spacy/cli/debug_data.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Add minor edits for whitespaces etc.

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Lj Miranda 2022-05-24 01:06:38 +08:00 committed by GitHub
parent 7ce3460b23
commit 1d34aa2b3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 415 additions and 5 deletions

View File

@ -6,6 +6,7 @@ import sys
import srsly
from wasabi import Printer, MESSAGES, msg
import typer
import math
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
@ -30,6 +31,12 @@ DEP_LABEL_THRESHOLD = 20
# Minimum number of expected examples to train a new pipeline
BLANK_MODEL_MIN_THRESHOLD = 100
BLANK_MODEL_THRESHOLD = 2000
# Arbitrary threshold where SpanCat performs well
SPAN_DISTINCT_THRESHOLD = 1
# Arbitrary threshold where SpanCat performs well
BOUNDARY_DISTINCT_THRESHOLD = 1
# Arbitrary threshold for filtering span lengths during reporting (percentage)
SPAN_LENGTH_THRESHOLD_PERCENTAGE = 90
@debug_cli.command(
@ -247,6 +254,69 @@ def debug_data(
msg.warn(f"No examples for texts WITHOUT new label '{label}'")
has_no_neg_warning = True
with msg.loading("Obtaining span characteristics..."):
span_characteristics = _get_span_characteristics(
train_dataset, gold_train_data, spans_key
)
msg.info(f"Span characteristics for spans_key '{spans_key}'")
msg.info("SD = Span Distinctiveness, BD = Boundary Distinctiveness")
_print_span_characteristics(span_characteristics)
_span_freqs = _get_spans_length_freq_dist(
gold_train_data["spans_length"][spans_key]
)
_filtered_span_freqs = _filter_spans_length_freq_dist(
_span_freqs, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE
)
msg.info(
f"Over {SPAN_LENGTH_THRESHOLD_PERCENTAGE}% of spans have lengths of 1 -- "
f"{max(_filtered_span_freqs.keys())} "
f"(min={span_characteristics['min_length']}, max={span_characteristics['max_length']}). "
f"The most common span lengths are: {_format_freqs(_filtered_span_freqs)}. "
"If you are using the n-gram suggester, note that omitting "
"infrequent n-gram lengths can greatly improve speed and "
"memory usage."
)
msg.text(
f"Full distribution of span lengths: {_format_freqs(_span_freqs)}",
show=verbose,
)
# Add report regarding span characteristics
if span_characteristics["avg_sd"] < SPAN_DISTINCT_THRESHOLD:
msg.warn("Spans may not be distinct from the rest of the corpus")
else:
msg.good("Spans are distinct from the rest of the corpus")
p_spans = span_characteristics["p_spans"].values()
all_span_tokens: Counter = sum(p_spans, Counter())
most_common_spans = [w for w, _ in all_span_tokens.most_common(10)]
msg.text(
"10 most common span tokens: {}".format(
_format_labels(most_common_spans)
),
show=verbose,
)
# Add report regarding span boundary characteristics
if span_characteristics["avg_bd"] < BOUNDARY_DISTINCT_THRESHOLD:
msg.warn("Boundary tokens are not distinct from the rest of the corpus")
else:
msg.good("Boundary tokens are distinct from the rest of the corpus")
p_bounds = span_characteristics["p_bounds"].values()
all_span_bound_tokens: Counter = sum(p_bounds, Counter())
most_common_bounds = [w for w, _ in all_span_bound_tokens.most_common(10)]
msg.text(
"10 most common span boundary tokens: {}".format(
_format_labels(most_common_bounds)
),
show=verbose,
)
if has_low_data_warning:
msg.text(
f"To train a new span type, your data should include at "
@ -647,6 +717,9 @@ def _compile_gold(
"words": Counter(),
"roots": Counter(),
"spancat": dict(),
"spans_length": dict(),
"spans_per_type": dict(),
"sb_per_type": dict(),
"ws_ents": 0,
"boundary_cross_ents": 0,
"n_words": 0,
@ -692,14 +765,59 @@ def _compile_gold(
elif label == "-":
data["ner"]["-"] += 1
if "spancat" in factory_names:
for span_key in list(eg.reference.spans.keys()):
if span_key not in data["spancat"]:
data["spancat"][span_key] = Counter()
for i, span in enumerate(eg.reference.spans[span_key]):
for spans_key in list(eg.reference.spans.keys()):
# Obtain the span frequency
if spans_key not in data["spancat"]:
data["spancat"][spans_key] = Counter()
for i, span in enumerate(eg.reference.spans[spans_key]):
if span.label_ is None:
continue
else:
data["spancat"][span_key][span.label_] += 1
data["spancat"][spans_key][span.label_] += 1
# Obtain the span length
if spans_key not in data["spans_length"]:
data["spans_length"][spans_key] = dict()
for span in gold.spans[spans_key]:
if span.label_ is None:
continue
if span.label_ not in data["spans_length"][spans_key]:
data["spans_length"][spans_key][span.label_] = []
data["spans_length"][spans_key][span.label_].append(len(span))
# Obtain spans per span type
if spans_key not in data["spans_per_type"]:
data["spans_per_type"][spans_key] = dict()
for span in gold.spans[spans_key]:
if span.label_ not in data["spans_per_type"][spans_key]:
data["spans_per_type"][spans_key][span.label_] = []
data["spans_per_type"][spans_key][span.label_].append(span)
# Obtain boundary tokens per span type
window_size = 1
if spans_key not in data["sb_per_type"]:
data["sb_per_type"][spans_key] = dict()
for span in gold.spans[spans_key]:
if span.label_ not in data["sb_per_type"][spans_key]:
# Creating a data structure that holds the start and
# end tokens for each span type
data["sb_per_type"][spans_key][span.label_] = {
"start": [],
"end": [],
}
for offset in range(window_size):
sb_start_idx = span.start - (offset + 1)
if sb_start_idx >= 0:
data["sb_per_type"][spans_key][span.label_]["start"].append(
gold[sb_start_idx : sb_start_idx + 1]
)
sb_end_idx = span.end + (offset + 1)
if sb_end_idx <= len(gold):
data["sb_per_type"][spans_key][span.label_]["end"].append(
gold[sb_end_idx - 1 : sb_end_idx]
)
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
data["cats"].update(gold.cats)
if any(val not in (0, 1) for val in gold.cats.values()):
@ -770,6 +888,16 @@ def _format_labels(
return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
def _format_freqs(freqs: Dict[int, float], sort: bool = True) -> str:
if sort:
freqs = dict(sorted(freqs.items()))
_freqs = [(str(k), v) for k, v in freqs.items()]
return ", ".join(
[f"{l} ({c}%)" for l, c in cast(Iterable[Tuple[str, float]], _freqs)]
)
def _get_examples_without_label(
data: Sequence[Example],
label: str,
@ -824,3 +952,158 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
labels[pipe.key] = set()
labels[pipe.key].update(pipe.labels)
return labels
def _gmean(l: List) -> float:
"""Compute geometric mean of a list"""
return math.exp(math.fsum(math.log(i) for i in l) / len(l))
def _wgt_average(metric: Dict[str, float], frequencies: Counter) -> float:
total = sum(value * frequencies[span_type] for span_type, value in metric.items())
return total / sum(frequencies.values())
def _get_distribution(docs, normalize: bool = True) -> Counter:
"""Get the frequency distribution given a set of Docs"""
word_counts: Counter = Counter()
for doc in docs:
for token in doc:
# Normalize the text
t = token.text.lower().replace("``", '"').replace("''", '"')
word_counts[t] += 1
if normalize:
total = sum(word_counts.values(), 0.0)
word_counts = Counter({k: v / total for k, v in word_counts.items()})
return word_counts
def _get_kl_divergence(p: Counter, q: Counter) -> float:
"""Compute the Kullback-Leibler divergence from two frequency distributions"""
total = 0.0
for word, p_word in p.items():
total += p_word * math.log(p_word / q[word])
return total
def _format_span_row(span_data: List[Dict], labels: List[str]) -> List[Any]:
"""Compile into one list for easier reporting"""
d = {
label: [label] + list(round(d[label], 2) for d in span_data) for label in labels
}
return list(d.values())
def _get_span_characteristics(
examples: List[Example], compiled_gold: Dict[str, Any], spans_key: str
) -> Dict[str, Any]:
"""Obtain all span characteristics"""
data_labels = compiled_gold["spancat"][spans_key]
# Get lengths
span_length = {
label: _gmean(l)
for label, l in compiled_gold["spans_length"][spans_key].items()
}
min_lengths = [min(l) for l in compiled_gold["spans_length"][spans_key].values()]
max_lengths = [max(l) for l in compiled_gold["spans_length"][spans_key].values()]
# Get relevant distributions: corpus, spans, span boundaries
p_corpus = _get_distribution([eg.reference for eg in examples], normalize=True)
p_spans = {
label: _get_distribution(spans, normalize=True)
for label, spans in compiled_gold["spans_per_type"][spans_key].items()
}
p_bounds = {
label: _get_distribution(sb["start"] + sb["end"], normalize=True)
for label, sb in compiled_gold["sb_per_type"][spans_key].items()
}
# Compute for actual span characteristics
span_distinctiveness = {
label: _get_kl_divergence(freq_dist, p_corpus)
for label, freq_dist in p_spans.items()
}
sb_distinctiveness = {
label: _get_kl_divergence(freq_dist, p_corpus)
for label, freq_dist in p_bounds.items()
}
return {
"sd": span_distinctiveness,
"bd": sb_distinctiveness,
"lengths": span_length,
"min_length": min(min_lengths),
"max_length": max(max_lengths),
"avg_sd": _wgt_average(span_distinctiveness, data_labels),
"avg_bd": _wgt_average(sb_distinctiveness, data_labels),
"avg_length": _wgt_average(span_length, data_labels),
"labels": list(data_labels.keys()),
"p_spans": p_spans,
"p_bounds": p_bounds,
}
def _print_span_characteristics(span_characteristics: Dict[str, Any]):
"""Print all span characteristics into a table"""
headers = ("Span Type", "Length", "SD", "BD")
# Prepare table data with all span characteristics
table_data = [
span_characteristics["lengths"],
span_characteristics["sd"],
span_characteristics["bd"],
]
table = _format_span_row(
span_data=table_data, labels=span_characteristics["labels"]
)
# Prepare table footer with weighted averages
footer_data = [
span_characteristics["avg_length"],
span_characteristics["avg_sd"],
span_characteristics["avg_bd"],
]
footer = ["Wgt. Average"] + [str(round(f, 2)) for f in footer_data]
msg.table(table, footer=footer, header=headers, divider=True)
def _get_spans_length_freq_dist(
length_dict: Dict, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE
) -> Dict[int, float]:
"""Get frequency distribution of spans length under a certain threshold"""
all_span_lengths = []
for _, lengths in length_dict.items():
all_span_lengths.extend(lengths)
freq_dist: Counter = Counter()
for i in all_span_lengths:
if freq_dist.get(i):
freq_dist[i] += 1
else:
freq_dist[i] = 1
# We will be working with percentages instead of raw counts
freq_dist_percentage = {}
for span_length, count in freq_dist.most_common():
percentage = (count / len(all_span_lengths)) * 100.0
percentage = round(percentage, 2)
freq_dist_percentage[span_length] = percentage
return freq_dist_percentage
def _filter_spans_length_freq_dist(
freq_dist: Dict[int, float], threshold: int
) -> Dict[int, float]:
"""Filter frequency distribution with respect to a threshold
We're going to filter all the span lengths that fall
around a percentage threshold when summed.
"""
total = 0.0
filtered_freq_dist = {}
for span_length, dist in freq_dist.items():
if total >= threshold:
break
else:
filtered_freq_dist[span_length] = dist
total += dist
return filtered_freq_dist

View File

@ -1,4 +1,7 @@
import os
import math
from random import sample
from typing import Counter
import pytest
import srsly
@ -14,6 +17,10 @@ from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
from spacy.cli.debug_data import _get_span_characteristics
from spacy.cli.debug_data import _print_span_characteristics
from spacy.cli.debug_data import _get_spans_length_freq_dist
from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies
@ -24,6 +31,7 @@ from spacy.lang.nl import Dutch
from spacy.language import Language
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.tokens import Doc
from spacy.tokens.span import Span
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
from spacy.training.converters import iob_to_docs
@ -740,3 +748,110 @@ def test_debug_data_compile_gold():
eg = Example(pred, ref)
data = _compile_gold([eg], ["ner"], nlp, True)
assert data["boundary_cross_ents"] == 1
def test_debug_data_compile_gold_for_spans():
nlp = English()
spans_key = "sc"
pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
eg = Example(pred, ref)
data = _compile_gold([eg], ["spancat"], nlp, True)
assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1})
assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}
assert data["spans_per_type"][spans_key] == {
"ORG": [Span(ref, 3, 6, "ORG")],
"GPE": [Span(ref, 5, 6, "GPE")],
}
assert data["sb_per_type"][spans_key] == {
"ORG": {"start": [ref[2:3]], "end": [ref[6:7]]},
"GPE": {"start": [ref[4:5]], "end": [ref[6:7]]},
}
def test_frequency_distribution_is_correct():
nlp = English()
docs = [
Doc(nlp.vocab, words=["Bank", "of", "China"]),
Doc(nlp.vocab, words=["China"]),
]
expected = Counter({"china": 0.5, "bank": 0.25, "of": 0.25})
freq_distribution = _get_distribution(docs, normalize=True)
assert freq_distribution == expected
def test_kl_divergence_computation_is_correct():
p = Counter({"a": 0.5, "b": 0.25})
q = Counter({"a": 0.25, "b": 0.50, "c": 0.15, "d": 0.10})
result = _get_kl_divergence(p, q)
expected = 0.1733
assert math.isclose(result, expected, rel_tol=1e-3)
def test_get_span_characteristics_return_value():
nlp = English()
spans_key = "sc"
pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
eg = Example(pred, ref)
examples = [eg]
data = _compile_gold(examples, ["spancat"], nlp, True)
span_characteristics = _get_span_characteristics(
examples=examples, compiled_gold=data, spans_key=spans_key
)
assert {"sd", "bd", "lengths"}.issubset(span_characteristics.keys())
assert span_characteristics["min_length"] == 1
assert span_characteristics["max_length"] == 3
def test_ensure_print_span_characteristics_wont_fail():
"""Test if interface between two methods aren't destroyed if refactored"""
nlp = English()
spans_key = "sc"
pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."])
ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
eg = Example(pred, ref)
examples = [eg]
data = _compile_gold(examples, ["spancat"], nlp, True)
span_characteristics = _get_span_characteristics(
examples=examples, compiled_gold=data, spans_key=spans_key
)
_print_span_characteristics(span_characteristics)
@pytest.mark.parametrize("threshold", [70, 80, 85, 90, 95])
def test_span_length_freq_dist_threshold_must_be_correct(threshold):
sample_span_lengths = {
"span_type_1": [1, 4, 4, 5],
"span_type_2": [5, 3, 3, 2],
"span_type_3": [3, 1, 3, 3],
}
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
assert sum(span_freqs.values()) >= threshold
def test_span_length_freq_dist_output_must_be_correct():
sample_span_lengths = {
"span_type_1": [1, 4, 4, 5],
"span_type_2": [5, 3, 3, 2],
"span_type_3": [3, 1, 3, 3],
}
threshold = 90
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
assert sum(span_freqs.values()) >= threshold
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]

View File

@ -466,6 +466,18 @@ takes the same arguments as `train` and reads settings off the
</Infobox>
<Infobox title="Notes on span characteristics" emoji="💡">
If your pipeline contains a `spancat` component, then this command will also
report span characteristics such as the average span length and the span (or
span boundary) distinctiveness. The distinctiveness measure shows how different
the tokens are with respect to the rest of the corpus using the KL-divergence of
the token distributions. To learn more, you can check out Papay et al.'s work on
[*Dissecting Span Identification Tasks with Performance Prediction* (EMNLP
2020)](https://aclanthology.org/2020.emnlp-main.396/).
</Infobox>
```cli
$ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbose] [--no-format] [overrides]
```