mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
Add spacy-span-analyzer to debug data (#10668)
* Rename to spans_key for consistency * Implement spans length in debug data * Implement how span bounds and spans are obtained In this commit, I implemented how span boundaries (the tokens) around a given span and spans are obtained. I've put them in the compile_gold() function so that it's accessible later on. I will do the actual computation of the span and boundary distinctiveness in the main function above. * Compute for p_spans and p_bounds * Add computation for SD and BD * Fix mypy issues * Add weighted average computation * Fix compile_gold conditional logic * Add test for frequency distribution computation * Add tests for kl-divergence computation * Fix weighted average computation * Make tables more compact by rounding them * Add more descriptive checks for spans * Modularize span computation methods In this commit, I added the _get_span_characteristics and _print_span_characteristics functions so that they can be reusable anywhere. * Remove unnecessary arguments and make fxs more compact * Update a few parameter arguments * Add tests for print_span and get_span methods * Update API to talk about span characteristics in brief * Add better reporting of spans_length * Add test for span length reporting * Update formatting of span length report Removed '' to indicate that it's not a string, then sort the n-grams by their length, not by their frequency. * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Show all frequency distribution when -V In this commit, I displayed the full frequency distribution of the span lengths when --verbose is passed. To make things simpler, I rewrote some of the formatter functions so that I can call them whenever. Another notable change is that instead of showing percentages as Integers, I showed them as floats (max 2-decimal places). I did this because it looks weird when it displays (0%). 
* Update logic on how total is computed The way the 90% thresholding is computed now is that we keep adding the percentages until we reach >= 90%. I also updated the wording and used the term "At least" to denote that >= 90% of your spans have these distributions. * Fix display when showing the threshold percentage * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Add better phrasing for span information * Update spacy/cli/debug_data.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Add minor edits for whitespaces etc. Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
7ce3460b23
commit
1d34aa2b3d
|
@ -6,6 +6,7 @@ import sys
|
||||||
import srsly
|
import srsly
|
||||||
from wasabi import Printer, MESSAGES, msg
|
from wasabi import Printer, MESSAGES, msg
|
||||||
import typer
|
import typer
|
||||||
|
import math
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||||
from ._util import import_code, debug_cli
|
from ._util import import_code, debug_cli
|
||||||
|
@ -30,6 +31,12 @@ DEP_LABEL_THRESHOLD = 20
|
||||||
# Minimum number of expected examples to train a new pipeline
|
# Minimum number of expected examples to train a new pipeline
|
||||||
BLANK_MODEL_MIN_THRESHOLD = 100
|
BLANK_MODEL_MIN_THRESHOLD = 100
|
||||||
BLANK_MODEL_THRESHOLD = 2000
|
BLANK_MODEL_THRESHOLD = 2000
|
||||||
|
# Arbitrary threshold where SpanCat performs well
|
||||||
|
SPAN_DISTINCT_THRESHOLD = 1
|
||||||
|
# Arbitrary threshold where SpanCat performs well
|
||||||
|
BOUNDARY_DISTINCT_THRESHOLD = 1
|
||||||
|
# Arbitrary threshold for filtering span lengths during reporting (percentage)
|
||||||
|
SPAN_LENGTH_THRESHOLD_PERCENTAGE = 90
|
||||||
|
|
||||||
|
|
||||||
@debug_cli.command(
|
@debug_cli.command(
|
||||||
|
@ -247,6 +254,69 @@ def debug_data(
|
||||||
msg.warn(f"No examples for texts WITHOUT new label '{label}'")
|
msg.warn(f"No examples for texts WITHOUT new label '{label}'")
|
||||||
has_no_neg_warning = True
|
has_no_neg_warning = True
|
||||||
|
|
||||||
|
with msg.loading("Obtaining span characteristics..."):
|
||||||
|
span_characteristics = _get_span_characteristics(
|
||||||
|
train_dataset, gold_train_data, spans_key
|
||||||
|
)
|
||||||
|
|
||||||
|
msg.info(f"Span characteristics for spans_key '{spans_key}'")
|
||||||
|
msg.info("SD = Span Distinctiveness, BD = Boundary Distinctiveness")
|
||||||
|
_print_span_characteristics(span_characteristics)
|
||||||
|
|
||||||
|
_span_freqs = _get_spans_length_freq_dist(
|
||||||
|
gold_train_data["spans_length"][spans_key]
|
||||||
|
)
|
||||||
|
_filtered_span_freqs = _filter_spans_length_freq_dist(
|
||||||
|
_span_freqs, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
msg.info(
|
||||||
|
f"Over {SPAN_LENGTH_THRESHOLD_PERCENTAGE}% of spans have lengths of 1 -- "
|
||||||
|
f"{max(_filtered_span_freqs.keys())} "
|
||||||
|
f"(min={span_characteristics['min_length']}, max={span_characteristics['max_length']}). "
|
||||||
|
f"The most common span lengths are: {_format_freqs(_filtered_span_freqs)}. "
|
||||||
|
"If you are using the n-gram suggester, note that omitting "
|
||||||
|
"infrequent n-gram lengths can greatly improve speed and "
|
||||||
|
"memory usage."
|
||||||
|
)
|
||||||
|
|
||||||
|
msg.text(
|
||||||
|
f"Full distribution of span lengths: {_format_freqs(_span_freqs)}",
|
||||||
|
show=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add report regarding span characteristics
|
||||||
|
if span_characteristics["avg_sd"] < SPAN_DISTINCT_THRESHOLD:
|
||||||
|
msg.warn("Spans may not be distinct from the rest of the corpus")
|
||||||
|
else:
|
||||||
|
msg.good("Spans are distinct from the rest of the corpus")
|
||||||
|
|
||||||
|
p_spans = span_characteristics["p_spans"].values()
|
||||||
|
all_span_tokens: Counter = sum(p_spans, Counter())
|
||||||
|
most_common_spans = [w for w, _ in all_span_tokens.most_common(10)]
|
||||||
|
msg.text(
|
||||||
|
"10 most common span tokens: {}".format(
|
||||||
|
_format_labels(most_common_spans)
|
||||||
|
),
|
||||||
|
show=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add report regarding span boundary characteristics
|
||||||
|
if span_characteristics["avg_bd"] < BOUNDARY_DISTINCT_THRESHOLD:
|
||||||
|
msg.warn("Boundary tokens are not distinct from the rest of the corpus")
|
||||||
|
else:
|
||||||
|
msg.good("Boundary tokens are distinct from the rest of the corpus")
|
||||||
|
|
||||||
|
p_bounds = span_characteristics["p_bounds"].values()
|
||||||
|
all_span_bound_tokens: Counter = sum(p_bounds, Counter())
|
||||||
|
most_common_bounds = [w for w, _ in all_span_bound_tokens.most_common(10)]
|
||||||
|
msg.text(
|
||||||
|
"10 most common span boundary tokens: {}".format(
|
||||||
|
_format_labels(most_common_bounds)
|
||||||
|
),
|
||||||
|
show=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
if has_low_data_warning:
|
if has_low_data_warning:
|
||||||
msg.text(
|
msg.text(
|
||||||
f"To train a new span type, your data should include at "
|
f"To train a new span type, your data should include at "
|
||||||
|
@ -647,6 +717,9 @@ def _compile_gold(
|
||||||
"words": Counter(),
|
"words": Counter(),
|
||||||
"roots": Counter(),
|
"roots": Counter(),
|
||||||
"spancat": dict(),
|
"spancat": dict(),
|
||||||
|
"spans_length": dict(),
|
||||||
|
"spans_per_type": dict(),
|
||||||
|
"sb_per_type": dict(),
|
||||||
"ws_ents": 0,
|
"ws_ents": 0,
|
||||||
"boundary_cross_ents": 0,
|
"boundary_cross_ents": 0,
|
||||||
"n_words": 0,
|
"n_words": 0,
|
||||||
|
@ -692,14 +765,59 @@ def _compile_gold(
|
||||||
elif label == "-":
|
elif label == "-":
|
||||||
data["ner"]["-"] += 1
|
data["ner"]["-"] += 1
|
||||||
if "spancat" in factory_names:
|
if "spancat" in factory_names:
|
||||||
for span_key in list(eg.reference.spans.keys()):
|
for spans_key in list(eg.reference.spans.keys()):
|
||||||
if span_key not in data["spancat"]:
|
# Obtain the span frequency
|
||||||
data["spancat"][span_key] = Counter()
|
if spans_key not in data["spancat"]:
|
||||||
for i, span in enumerate(eg.reference.spans[span_key]):
|
data["spancat"][spans_key] = Counter()
|
||||||
|
for i, span in enumerate(eg.reference.spans[spans_key]):
|
||||||
if span.label_ is None:
|
if span.label_ is None:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
data["spancat"][span_key][span.label_] += 1
|
data["spancat"][spans_key][span.label_] += 1
|
||||||
|
|
||||||
|
# Obtain the span length
|
||||||
|
if spans_key not in data["spans_length"]:
|
||||||
|
data["spans_length"][spans_key] = dict()
|
||||||
|
for span in gold.spans[spans_key]:
|
||||||
|
if span.label_ is None:
|
||||||
|
continue
|
||||||
|
if span.label_ not in data["spans_length"][spans_key]:
|
||||||
|
data["spans_length"][spans_key][span.label_] = []
|
||||||
|
data["spans_length"][spans_key][span.label_].append(len(span))
|
||||||
|
|
||||||
|
# Obtain spans per span type
|
||||||
|
if spans_key not in data["spans_per_type"]:
|
||||||
|
data["spans_per_type"][spans_key] = dict()
|
||||||
|
for span in gold.spans[spans_key]:
|
||||||
|
if span.label_ not in data["spans_per_type"][spans_key]:
|
||||||
|
data["spans_per_type"][spans_key][span.label_] = []
|
||||||
|
data["spans_per_type"][spans_key][span.label_].append(span)
|
||||||
|
|
||||||
|
# Obtain boundary tokens per span type
|
||||||
|
window_size = 1
|
||||||
|
if spans_key not in data["sb_per_type"]:
|
||||||
|
data["sb_per_type"][spans_key] = dict()
|
||||||
|
for span in gold.spans[spans_key]:
|
||||||
|
if span.label_ not in data["sb_per_type"][spans_key]:
|
||||||
|
# Creating a data structure that holds the start and
|
||||||
|
# end tokens for each span type
|
||||||
|
data["sb_per_type"][spans_key][span.label_] = {
|
||||||
|
"start": [],
|
||||||
|
"end": [],
|
||||||
|
}
|
||||||
|
for offset in range(window_size):
|
||||||
|
sb_start_idx = span.start - (offset + 1)
|
||||||
|
if sb_start_idx >= 0:
|
||||||
|
data["sb_per_type"][spans_key][span.label_]["start"].append(
|
||||||
|
gold[sb_start_idx : sb_start_idx + 1]
|
||||||
|
)
|
||||||
|
|
||||||
|
sb_end_idx = span.end + (offset + 1)
|
||||||
|
if sb_end_idx <= len(gold):
|
||||||
|
data["sb_per_type"][spans_key][span.label_]["end"].append(
|
||||||
|
gold[sb_end_idx - 1 : sb_end_idx]
|
||||||
|
)
|
||||||
|
|
||||||
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
|
if "textcat" in factory_names or "textcat_multilabel" in factory_names:
|
||||||
data["cats"].update(gold.cats)
|
data["cats"].update(gold.cats)
|
||||||
if any(val not in (0, 1) for val in gold.cats.values()):
|
if any(val not in (0, 1) for val in gold.cats.values()):
|
||||||
|
@ -770,6 +888,16 @@ def _format_labels(
|
||||||
return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
|
return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
|
||||||
|
|
||||||
|
|
||||||
|
def _format_freqs(freqs: Dict[int, float], sort: bool = True) -> str:
|
||||||
|
if sort:
|
||||||
|
freqs = dict(sorted(freqs.items()))
|
||||||
|
|
||||||
|
_freqs = [(str(k), v) for k, v in freqs.items()]
|
||||||
|
return ", ".join(
|
||||||
|
[f"{l} ({c}%)" for l, c in cast(Iterable[Tuple[str, float]], _freqs)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_examples_without_label(
|
def _get_examples_without_label(
|
||||||
data: Sequence[Example],
|
data: Sequence[Example],
|
||||||
label: str,
|
label: str,
|
||||||
|
@ -824,3 +952,158 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
|
||||||
labels[pipe.key] = set()
|
labels[pipe.key] = set()
|
||||||
labels[pipe.key].update(pipe.labels)
|
labels[pipe.key].update(pipe.labels)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def _gmean(l: List) -> float:
|
||||||
|
"""Compute geometric mean of a list"""
|
||||||
|
return math.exp(math.fsum(math.log(i) for i in l) / len(l))
|
||||||
|
|
||||||
|
|
||||||
|
def _wgt_average(metric: Dict[str, float], frequencies: Counter) -> float:
|
||||||
|
total = sum(value * frequencies[span_type] for span_type, value in metric.items())
|
||||||
|
return total / sum(frequencies.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_distribution(docs, normalize: bool = True) -> Counter:
|
||||||
|
"""Get the frequency distribution given a set of Docs"""
|
||||||
|
word_counts: Counter = Counter()
|
||||||
|
for doc in docs:
|
||||||
|
for token in doc:
|
||||||
|
# Normalize the text
|
||||||
|
t = token.text.lower().replace("``", '"').replace("''", '"')
|
||||||
|
word_counts[t] += 1
|
||||||
|
if normalize:
|
||||||
|
total = sum(word_counts.values(), 0.0)
|
||||||
|
word_counts = Counter({k: v / total for k, v in word_counts.items()})
|
||||||
|
return word_counts
|
||||||
|
|
||||||
|
|
||||||
|
def _get_kl_divergence(p: Counter, q: Counter) -> float:
|
||||||
|
"""Compute the Kullback-Leibler divergence from two frequency distributions"""
|
||||||
|
total = 0.0
|
||||||
|
for word, p_word in p.items():
|
||||||
|
total += p_word * math.log(p_word / q[word])
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def _format_span_row(span_data: List[Dict], labels: List[str]) -> List[Any]:
|
||||||
|
"""Compile into one list for easier reporting"""
|
||||||
|
d = {
|
||||||
|
label: [label] + list(round(d[label], 2) for d in span_data) for label in labels
|
||||||
|
}
|
||||||
|
return list(d.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_span_characteristics(
    examples: List[Example], compiled_gold: Dict[str, Any], spans_key: str
) -> Dict[str, Any]:
    """Obtain all span characteristics for one spans key.

    examples (List[Example]): Examples whose reference docs form the corpus
        distribution.
    compiled_gold (Dict[str, Any]): Compiled gold data; the "spancat",
        "spans_length", "spans_per_type" and "sb_per_type" entries are read
        for `spans_key`.
    spans_key (str): The spans key to report on.
    RETURNS (Dict[str, Any]): Per-label span distinctiveness ("sd"), boundary
        distinctiveness ("bd") and geometric-mean length ("lengths"), the
        overall "min_length" and "max_length", the frequency-weighted
        averages ("avg_sd", "avg_bd", "avg_length"), the "labels", and the
        token distributions "p_spans" and "p_bounds".
    """
    label_counts = compiled_gold["spancat"][spans_key]

    # Length statistics per label
    lengths_by_label = compiled_gold["spans_length"][spans_key]
    span_length = {}
    for label, lengths in lengths_by_label.items():
        span_length[label] = _gmean(lengths)
    min_lengths = [min(lengths) for lengths in lengths_by_label.values()]
    max_lengths = [max(lengths) for lengths in lengths_by_label.values()]

    # Token distributions: whole corpus, span tokens, span boundary tokens
    reference_docs = [eg.reference for eg in examples]
    p_corpus = _get_distribution(reference_docs, normalize=True)
    p_spans = {}
    for label, spans in compiled_gold["spans_per_type"][spans_key].items():
        p_spans[label] = _get_distribution(spans, normalize=True)
    p_bounds = {}
    for label, sb in compiled_gold["sb_per_type"][spans_key].items():
        boundary_tokens = sb["start"] + sb["end"]
        p_bounds[label] = _get_distribution(boundary_tokens, normalize=True)

    # Distinctiveness is the KL divergence from the corpus distribution
    span_distinctiveness = {}
    for label, freq_dist in p_spans.items():
        span_distinctiveness[label] = _get_kl_divergence(freq_dist, p_corpus)
    sb_distinctiveness = {}
    for label, freq_dist in p_bounds.items():
        sb_distinctiveness[label] = _get_kl_divergence(freq_dist, p_corpus)

    characteristics = {
        "sd": span_distinctiveness,
        "bd": sb_distinctiveness,
        "lengths": span_length,
        "min_length": min(min_lengths),
        "max_length": max(max_lengths),
        "avg_sd": _wgt_average(span_distinctiveness, label_counts),
        "avg_bd": _wgt_average(sb_distinctiveness, label_counts),
        "avg_length": _wgt_average(span_length, label_counts),
        "labels": list(label_counts.keys()),
        "p_spans": p_spans,
        "p_bounds": p_bounds,
    }
    return characteristics
|
||||||
|
|
||||||
|
|
||||||
|
def _print_span_characteristics(span_characteristics: Dict[str, Any]):
    """Print all span characteristics as a table.

    span_characteristics (Dict[str, Any]): Must provide the per-label dicts
        "lengths", "sd" and "bd", the list "labels", and the averages
        "avg_length", "avg_sd" and "avg_bd".
    """
    headers = ("Span Type", "Length", "SD", "BD")
    column_keys = ("lengths", "sd", "bd")
    average_keys = ("avg_length", "avg_sd", "avg_bd")

    # Body: one row per label with its length, SD and BD
    table = _format_span_row(
        span_data=[span_characteristics[key] for key in column_keys],
        labels=span_characteristics["labels"],
    )

    # Footer: weighted averages, rounded to two decimals
    footer = ["Wgt. Average"]
    footer.extend(str(round(span_characteristics[key], 2)) for key in average_keys)

    msg.table(table, footer=footer, header=headers, divider=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_spans_length_freq_dist(
    length_dict: Dict, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE
) -> Dict[int, float]:
    """Get the frequency distribution of span lengths as percentages.

    The lengths of all span types are pooled together. No filtering happens
    here: the full distribution is always returned.

    length_dict (Dict): Mapping of span type to a list of span lengths.
    threshold: Unused by this function; kept so existing callers that pass
        it keep working.
    RETURNS (Dict[int, float]): Span length mapped to the percentage of spans
        with that length, rounded to two decimals. Entries are ordered from
        most to least frequent.
    """
    all_span_lengths = [
        length for lengths in length_dict.values() for length in lengths
    ]
    freq_dist: Counter = Counter(all_span_lengths)
    n_spans = len(all_span_lengths)

    # We will be working with percentages instead of raw counts
    return {
        span_length: round((count / n_spans) * 100.0, 2)
        for span_length, count in freq_dist.most_common()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_spans_length_freq_dist(
|
||||||
|
freq_dist: Dict[int, float], threshold: int
|
||||||
|
) -> Dict[int, float]:
|
||||||
|
"""Filter frequency distribution with respect to a threshold
|
||||||
|
|
||||||
|
We're going to filter all the span lengths that fall
|
||||||
|
around a percentage threshold when summed.
|
||||||
|
"""
|
||||||
|
total = 0.0
|
||||||
|
filtered_freq_dist = {}
|
||||||
|
for span_length, dist in freq_dist.items():
|
||||||
|
if total >= threshold:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
filtered_freq_dist[span_length] = dist
|
||||||
|
total += dist
|
||||||
|
return filtered_freq_dist
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
|
from random import sample
|
||||||
|
from typing import Counter
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import srsly
|
import srsly
|
||||||
|
@ -14,6 +17,10 @@ from spacy.cli._util import substitute_project_variables
|
||||||
from spacy.cli._util import validate_project_commands
|
from spacy.cli._util import validate_project_commands
|
||||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||||
|
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||||
|
from spacy.cli.debug_data import _get_span_characteristics
|
||||||
|
from spacy.cli.debug_data import _print_span_characteristics
|
||||||
|
from spacy.cli.debug_data import _get_spans_length_freq_dist
|
||||||
from spacy.cli.download import get_compatibility, get_version
|
from spacy.cli.download import get_compatibility, get_version
|
||||||
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||||
from spacy.cli.package import get_third_party_dependencies
|
from spacy.cli.package import get_third_party_dependencies
|
||||||
|
@ -24,6 +31,7 @@ from spacy.lang.nl import Dutch
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
from spacy.tokens.span import Span
|
||||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||||
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
||||||
from spacy.training.converters import iob_to_docs
|
from spacy.training.converters import iob_to_docs
|
||||||
|
@ -740,3 +748,110 @@ def test_debug_data_compile_gold():
|
||||||
eg = Example(pred, ref)
|
eg = Example(pred, ref)
|
||||||
data = _compile_gold([eg], ["ner"], nlp, True)
|
data = _compile_gold([eg], ["ner"], nlp, True)
|
||||||
assert data["boundary_cross_ents"] == 1
|
assert data["boundary_cross_ents"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_debug_data_compile_gold_for_spans():
    """Check the span-related entries that _compile_gold collects."""
    nlp = English()
    spans_key = "sc"
    words = ["Welcome", "to", "the", "Bank", "of", "China", "."]

    pred = Doc(nlp.vocab, words=words)
    pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
    ref = Doc(nlp.vocab, words=words)
    ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]

    data = _compile_gold([Example(pred, ref)], ["spancat"], nlp, True)

    # Label frequencies and span lengths
    assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1})
    assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}

    # The spans themselves, grouped by label
    expected_spans = {
        "ORG": [Span(ref, 3, 6, "ORG")],
        "GPE": [Span(ref, 5, 6, "GPE")],
    }
    assert data["spans_per_type"][spans_key] == expected_spans

    # One token before and one token after each span
    expected_bounds = {
        "ORG": {"start": [ref[2:3]], "end": [ref[6:7]]},
        "GPE": {"start": [ref[4:5]], "end": [ref[6:7]]},
    }
    assert data["sb_per_type"][spans_key] == expected_bounds
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_distribution_is_correct():
    """Normalized token frequencies are computed across all docs."""
    nlp = English()
    texts = (["Bank", "of", "China"], ["China"])
    docs = [Doc(nlp.vocab, words=words) for words in texts]

    freq_distribution = _get_distribution(docs, normalize=True)

    assert freq_distribution == Counter({"china": 0.5, "bank": 0.25, "of": 0.25})
|
||||||
|
|
||||||
|
|
||||||
|
def test_kl_divergence_computation_is_correct():
    """The KL divergence of two small distributions matches a known value."""
    p = Counter({"a": 0.5, "b": 0.25})
    q = Counter({"a": 0.25, "b": 0.50, "c": 0.15, "d": 0.10})

    divergence = _get_kl_divergence(p, q)

    assert math.isclose(divergence, 0.1733, rel_tol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_span_characteristics_return_value():
    """The characteristics dict has the expected keys and length bounds."""
    nlp = English()
    spans_key = "sc"
    words = ["Welcome", "to", "the", "Bank", "of", "China", "."]

    pred = Doc(nlp.vocab, words=words)
    pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
    ref = Doc(nlp.vocab, words=words)
    ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
    examples = [Example(pred, ref)]

    data = _compile_gold(examples, ["spancat"], nlp, True)
    span_characteristics = _get_span_characteristics(
        examples=examples, compiled_gold=data, spans_key=spans_key
    )

    for key in ("sd", "bd", "lengths"):
        assert key in span_characteristics.keys()
    assert span_characteristics["min_length"] == 1
    assert span_characteristics["max_length"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_print_span_characteristics_wont_fail():
    """Test that the output of _get_span_characteristics can be printed.

    Guards the interface between the two functions against refactoring.
    """
    nlp = English()
    spans_key = "sc"
    words = ["Welcome", "to", "the", "Bank", "of", "China", "."]

    pred = Doc(nlp.vocab, words=words)
    pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")]
    ref = Doc(nlp.vocab, words=words)
    ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
    examples = [Example(pred, ref)]

    data = _compile_gold(examples, ["spancat"], nlp, True)
    span_characteristics = _get_span_characteristics(
        examples=examples, compiled_gold=data, spans_key=spans_key
    )

    # Only checks that printing does not raise
    _print_span_characteristics(span_characteristics)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("threshold", [70, 80, 85, 90, 95])
def test_span_length_freq_dist_threshold_must_be_correct(threshold):
    """Filtering the length distribution keeps at least `threshold` percent.

    _get_spans_length_freq_dist does not use its threshold argument, so the
    thresholding is checked on _filter_spans_length_freq_dist instead.
    """
    # Imported locally so this test does not rely on the module-level imports.
    from spacy.cli.debug_data import _filter_spans_length_freq_dist

    sample_span_lengths = {
        "span_type_1": [1, 4, 4, 5],
        "span_type_2": [5, 3, 3, 2],
        "span_type_3": [3, 1, 3, 3],
    }
    span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
    filtered_freqs = _filter_spans_length_freq_dist(span_freqs, threshold=threshold)

    assert sum(filtered_freqs.values()) >= threshold
    # The filter keeps a leading slice of the full distribution
    kept_lengths = list(filtered_freqs.keys())
    assert kept_lengths == list(span_freqs.keys())[: len(kept_lengths)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_span_length_freq_dist_output_must_be_correct():
    """Span lengths come back ordered from most to least frequent."""
    threshold = 90
    sample_span_lengths = {
        "span_type_1": [1, 4, 4, 5],
        "span_type_2": [5, 3, 3, 2],
        "span_type_3": [3, 1, 3, 3],
    }

    span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)

    total_percentage = sum(span_freqs.values())
    assert total_percentage >= threshold
    assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
||||||
|
|
|
@ -466,6 +466,18 @@ takes the same arguments as `train` and reads settings off the
|
||||||
|
|
||||||
</Infobox>
|
</Infobox>
|
||||||
|
|
||||||
|
<Infobox title="Notes on span characteristics" emoji="💡">
|
||||||
|
|
||||||
|
If your pipeline contains a `spancat` component, then this command will also
|
||||||
|
report span characteristics such as the average span length and the span (or
|
||||||
|
span boundary) distinctiveness. The distinctiveness measure shows how different
|
||||||
|
the tokens are with respect to the rest of the corpus using the KL-divergence of
|
||||||
|
the token distributions. To learn more, you can check out Papay et al.'s work on
|
||||||
|
[*Dissecting Span Identification Tasks with Performance Prediction* (EMNLP
|
||||||
|
2020)](https://aclanthology.org/2020.emnlp-main.396/).
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
```cli
|
```cli
|
||||||
$ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbose] [--no-format] [overrides]
|
$ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbose] [--no-format] [overrides]
|
||||||
```
|
```
|
||||||
|
|
Loading…
Reference in New Issue
Block a user