From 1d34aa2b3dd1ba0931dcb1863dfbeba6ae5b912d Mon Sep 17 00:00:00 2001 From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com> Date: Tue, 24 May 2022 01:06:38 +0800 Subject: [PATCH] Add spacy-span-analyzer to debug data (#10668) * Rename to spans_key for consistency * Implement spans length in debug data * Implement how span bounds and spans are obtained In this commit, I implemented how span boundaries (the tokens) around a given span and spans are obtained. I've put them in the compile_gold() function so that it's accessible later on. I will do the actual computation of the span and boundary distinctiveness in the main function above. * Compute for p_spans and p_bounds * Add computation for SD and BD * Fix mypy issues * Add weighted average computation * Fix compile_gold conditional logic * Add test for frequency distribution computation * Add tests for kl-divergence computation * Fix weighted average computation * Make tables more compact by rounding them * Add more descriptive checks for spans * Modularize span computation methods In this commit, I added the _get_span_characteristics and _print_span_characteristics functions so that they can be reusable anywhere. * Remove unnecessary arguments and make fxs more compact * Update a few parameter arguments * Add tests for print_span and get_span methods * Update API to talk about span characteristics in brief * Add better reporting of spans_length * Add test for span length reporting * Update formatting of span length report Removed '' to indicate that it's not a string, then sort the n-grams by their length, not by their frequency. * Apply suggestions from code review Co-authored-by: Adriane Boyd * Show all frequency distribution when -V In this commit, I displayed the full frequency distribution of the span lengths when --verbose is passed. To make things simpler, I rewrote some of the formatter functions so that I can call them whenever. Another notable change is that instead of showing percentages as Integers, I showed them as floats (max 2-decimal places). I did this because it looks weird when it displays (0%). * Update logic on how total is computed The way the 90% thresholding is computed now is that we keep adding the percentages until we reach >= 90%. I also updated the wording and used the term "At least" to denote that >= 90% of your spans have these distributions. * Fix display when showing the threshold percentage * Apply suggestions from code review Co-authored-by: Adriane Boyd * Add better phrasing for span information * Update spacy/cli/debug_data.py Co-authored-by: Adriane Boyd * Add minor edits for whitespaces etc. Co-authored-by: Adriane Boyd Co-authored-by: Adriane Boyd --- spacy/cli/debug_data.py | 293 +++++++++++++++++++++++++++++++++++++++- spacy/tests/test_cli.py | 115 ++++++++++++++++ website/docs/api/cli.md | 12 ++ 3 files changed, 415 insertions(+), 5 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index f94319d1d..0061515c6 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -6,6 +6,7 @@ import sys import srsly from wasabi import Printer, MESSAGES, msg import typer +import math from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import import_code, debug_cli @@ -30,6 +31,12 @@ DEP_LABEL_THRESHOLD = 20 # Minimum number of expected examples to train a new pipeline BLANK_MODEL_MIN_THRESHOLD = 100 BLANK_MODEL_THRESHOLD = 2000 +# Arbitrary threshold where SpanCat performs well +SPAN_DISTINCT_THRESHOLD = 1 +# Arbitrary threshold where SpanCat performs well +BOUNDARY_DISTINCT_THRESHOLD = 1 +# Arbitrary threshold for filtering span lengths during reporting (percentage) +SPAN_LENGTH_THRESHOLD_PERCENTAGE = 90 @debug_cli.command( @@ -247,6 +254,69 @@ def debug_data( msg.warn(f"No examples for texts WITHOUT new label '{label}'") has_no_neg_warning = True + with msg.loading("Obtaining span characteristics..."): + span_characteristics = _get_span_characteristics( + train_dataset, gold_train_data, spans_key + ) + + msg.info(f"Span characteristics for spans_key '{spans_key}'") + msg.info("SD = Span Distinctiveness, BD = Boundary Distinctiveness") + _print_span_characteristics(span_characteristics) + + _span_freqs = _get_spans_length_freq_dist( + gold_train_data["spans_length"][spans_key] + ) + _filtered_span_freqs = _filter_spans_length_freq_dist( + _span_freqs, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE + ) + + msg.info( + f"Over {SPAN_LENGTH_THRESHOLD_PERCENTAGE}% of spans have lengths of 1 -- " + f"{max(_filtered_span_freqs.keys())} " + f"(min={span_characteristics['min_length']}, max={span_characteristics['max_length']}). " + f"The most common span lengths are: {_format_freqs(_filtered_span_freqs)}. " + "If you are using the n-gram suggester, note that omitting " + "infrequent n-gram lengths can greatly improve speed and " + "memory usage." + ) + + msg.text( + f"Full distribution of span lengths: {_format_freqs(_span_freqs)}", + show=verbose, + ) + + # Add report regarding span characteristics + if span_characteristics["avg_sd"] < SPAN_DISTINCT_THRESHOLD: + msg.warn("Spans may not be distinct from the rest of the corpus") + else: + msg.good("Spans are distinct from the rest of the corpus") + + p_spans = span_characteristics["p_spans"].values() + all_span_tokens: Counter = sum(p_spans, Counter()) + most_common_spans = [w for w, _ in all_span_tokens.most_common(10)] + msg.text( + "10 most common span tokens: {}".format( + _format_labels(most_common_spans) + ), + show=verbose, + ) + + # Add report regarding span boundary characteristics + if span_characteristics["avg_bd"] < BOUNDARY_DISTINCT_THRESHOLD: + msg.warn("Boundary tokens are not distinct from the rest of the corpus") + else: + msg.good("Boundary tokens are distinct from the rest of the corpus") + + p_bounds = span_characteristics["p_bounds"].values() + all_span_bound_tokens: Counter = sum(p_bounds, Counter()) + most_common_bounds = [w for w, _ in all_span_bound_tokens.most_common(10)] + msg.text( + "10 most common span boundary tokens: {}".format( + _format_labels(most_common_bounds) + ), + show=verbose, + ) + if has_low_data_warning: msg.text( f"To train a new span type, your data should include at " @@ -647,6 +717,9 @@ def _compile_gold( "words": Counter(), "roots": Counter(), "spancat": dict(), + "spans_length": dict(), + "spans_per_type": dict(), + "sb_per_type": dict(), "ws_ents": 0, "boundary_cross_ents": 0, "n_words": 0, @@ -692,14 +765,59 @@ def _compile_gold( elif label == "-": data["ner"]["-"] += 1 if "spancat" in factory_names: - for span_key in list(eg.reference.spans.keys()): - if span_key not in data["spancat"]: - data["spancat"][span_key] = Counter() - for i, span in enumerate(eg.reference.spans[span_key]): + for spans_key in list(eg.reference.spans.keys()): + # Obtain the span frequency + if spans_key not in data["spancat"]: + data["spancat"][spans_key] = Counter() + for i, span in enumerate(eg.reference.spans[spans_key]): if span.label_ is None: continue else: - data["spancat"][span_key][span.label_] += 1 + data["spancat"][spans_key][span.label_] += 1 + + # Obtain the span length + if spans_key not in data["spans_length"]: + data["spans_length"][spans_key] = dict() + for span in gold.spans[spans_key]: + if span.label_ is None: + continue + if span.label_ not in data["spans_length"][spans_key]: + data["spans_length"][spans_key][span.label_] = [] + data["spans_length"][spans_key][span.label_].append(len(span)) + + # Obtain spans per span type + if spans_key not in data["spans_per_type"]: + data["spans_per_type"][spans_key] = dict() + for span in gold.spans[spans_key]: + if span.label_ not in data["spans_per_type"][spans_key]: + data["spans_per_type"][spans_key][span.label_] = [] + data["spans_per_type"][spans_key][span.label_].append(span) + + # Obtain boundary tokens per span type + window_size = 1 + if spans_key not in data["sb_per_type"]: + data["sb_per_type"][spans_key] = dict() + for span in gold.spans[spans_key]: + if span.label_ not in data["sb_per_type"][spans_key]: + # Creating a data structure that holds the start and + # end tokens for each span type + data["sb_per_type"][spans_key][span.label_] = { + "start": [], + "end": [], + } + for offset in range(window_size): + sb_start_idx = span.start - (offset + 1) + if sb_start_idx >= 0: + data["sb_per_type"][spans_key][span.label_]["start"].append( + gold[sb_start_idx : sb_start_idx + 1] + ) + + sb_end_idx = span.end + (offset + 1) + if sb_end_idx <= len(gold): + data["sb_per_type"][spans_key][span.label_]["end"].append( + gold[sb_end_idx - 1 : sb_end_idx] + ) + if "textcat" in factory_names or "textcat_multilabel" in factory_names: data["cats"].update(gold.cats) if any(val not in (0, 1) for val in gold.cats.values()): @@ -770,6 +888,16 @@ def _format_labels( return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)]) +def _format_freqs(freqs: Dict[int, float], sort: bool = True) -> str: + if sort: + freqs = dict(sorted(freqs.items())) + + _freqs = [(str(k), v) for k, v in freqs.items()] + return ", ".join( + [f"{l} ({c}%)" for l, c in cast(Iterable[Tuple[str, float]], _freqs)] + ) + + def _get_examples_without_label( data: Sequence[Example], label: str, @@ -824,3 +952,158 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]: labels[pipe.key] = set() labels[pipe.key].update(pipe.labels) return labels + + +def _gmean(l: List) -> float: + """Compute geometric mean of a list""" + return math.exp(math.fsum(math.log(i) for i in l) / len(l)) + + +def _wgt_average(metric: Dict[str, float], frequencies: Counter) -> float: + total = sum(value * frequencies[span_type] for span_type, value in metric.items()) + return total / sum(frequencies.values()) + + +def _get_distribution(docs, normalize: bool = True) -> Counter: + """Get the frequency distribution given a set of Docs""" + word_counts: Counter = Counter() + for doc in docs: + for token in doc: + # Normalize the text + t = token.text.lower().replace("``", '"').replace("''", '"') + word_counts[t] += 1 + if normalize: + total = sum(word_counts.values(), 0.0) + word_counts = Counter({k: v / total for k, v in word_counts.items()}) + return word_counts + + +def _get_kl_divergence(p: Counter, q: Counter) -> float: + """Compute the Kullback-Leibler divergence from two frequency distributions""" + total = 0.0 + for word, p_word in p.items(): + total += p_word * math.log(p_word / q[word]) + return total + + +def _format_span_row(span_data: List[Dict], labels: List[str]) -> List[Any]: + """Compile into one list for easier reporting""" + d = { + label: [label] + list(round(d[label], 2) for d in span_data) for label in labels + } + return list(d.values()) + + +def _get_span_characteristics( + examples: List[Example], compiled_gold: Dict[str, Any], spans_key: str +) -> Dict[str, Any]: + """Obtain all span characteristics""" + data_labels = compiled_gold["spancat"][spans_key] + # Get lengths + span_length = { + label: _gmean(l) + for label, l in compiled_gold["spans_length"][spans_key].items() + } + min_lengths = [min(l) for l in compiled_gold["spans_length"][spans_key].values()] + max_lengths = [max(l) for l in compiled_gold["spans_length"][spans_key].values()] + + # Get relevant distributions: corpus, spans, span boundaries + p_corpus = _get_distribution([eg.reference for eg in examples], normalize=True) + p_spans = { + label: _get_distribution(spans, normalize=True) + for label, spans in compiled_gold["spans_per_type"][spans_key].items() + } + p_bounds = { + label: _get_distribution(sb["start"] + sb["end"], normalize=True) + for label, sb in compiled_gold["sb_per_type"][spans_key].items() + } + + # Compute for actual span characteristics + span_distinctiveness = { + label: _get_kl_divergence(freq_dist, p_corpus) + for label, freq_dist in p_spans.items() + } + sb_distinctiveness = { + label: _get_kl_divergence(freq_dist, p_corpus) + for label, freq_dist in p_bounds.items() + } + + return { + "sd": span_distinctiveness, + "bd": sb_distinctiveness, + "lengths": span_length, + "min_length": min(min_lengths), + "max_length": max(max_lengths), + "avg_sd": _wgt_average(span_distinctiveness, data_labels), + "avg_bd": _wgt_average(sb_distinctiveness, data_labels), + "avg_length": _wgt_average(span_length, data_labels), + "labels": list(data_labels.keys()), + "p_spans": p_spans, + "p_bounds": p_bounds, + } + + +def _print_span_characteristics(span_characteristics: Dict[str, Any]): + """Print all span characteristics into a table""" + headers = ("Span Type", "Length", "SD", "BD") + # Prepare table data with all span characteristics + table_data = [ + span_characteristics["lengths"], + span_characteristics["sd"], + span_characteristics["bd"], + ] + table = _format_span_row( + span_data=table_data, labels=span_characteristics["labels"] + ) + # Prepare table footer with weighted averages + footer_data = [ + span_characteristics["avg_length"], + span_characteristics["avg_sd"], + span_characteristics["avg_bd"], + ] + footer = ["Wgt. Average"] + [str(round(f, 2)) for f in footer_data] + msg.table(table, footer=footer, header=headers, divider=True) + + +def _get_spans_length_freq_dist( + length_dict: Dict, threshold=SPAN_LENGTH_THRESHOLD_PERCENTAGE +) -> Dict[int, float]: + """Get frequency distribution of spans length under a certain threshold""" + all_span_lengths = [] + for _, lengths in length_dict.items(): + all_span_lengths.extend(lengths) + + freq_dist: Counter = Counter() + for i in all_span_lengths: + if freq_dist.get(i): + freq_dist[i] += 1 + else: + freq_dist[i] = 1 + + # We will be working with percentages instead of raw counts + freq_dist_percentage = {} + for span_length, count in freq_dist.most_common(): + percentage = (count / len(all_span_lengths)) * 100.0 + percentage = round(percentage, 2) + freq_dist_percentage[span_length] = percentage + + return freq_dist_percentage + + +def _filter_spans_length_freq_dist( + freq_dist: Dict[int, float], threshold: int +) -> Dict[int, float]: + """Filter frequency distribution with respect to a threshold + + We're going to filter all the span lengths that fall + around a percentage threshold when summed. + """ + total = 0.0 + filtered_freq_dist = {} + for span_length, dist in freq_dist.items(): + if total >= threshold: + break + else: + filtered_freq_dist[span_length] = dist + total += dist + return filtered_freq_dist diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 3ef56d9f6..838e00369 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -1,4 +1,7 @@ import os +import math +from random import sample +from typing import Counter import pytest import srsly @@ -14,6 +17,10 @@ from spacy.cli._util import substitute_project_variables from spacy.cli._util import validate_project_commands from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _get_labels_from_spancat +from spacy.cli.debug_data import _get_distribution, _get_kl_divergence +from spacy.cli.debug_data import _get_span_characteristics +from spacy.cli.debug_data import _print_span_characteristics +from spacy.cli.debug_data import _get_spans_length_freq_dist from spacy.cli.download import get_compatibility, get_version from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config from spacy.cli.package import get_third_party_dependencies @@ -24,6 +31,7 @@ from spacy.lang.nl import Dutch from spacy.language import Language from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.tokens import Doc +from spacy.tokens.span import Span from spacy.training import Example, docs_to_json, offsets_to_biluo_tags from spacy.training.converters import conll_ner_to_docs, conllu_to_docs from spacy.training.converters import iob_to_docs @@ -740,3 +748,110 @@ def test_debug_data_compile_gold(): eg = Example(pred, ref) data = _compile_gold([eg], ["ner"], nlp, True) assert data["boundary_cross_ents"] == 1 + + +def test_debug_data_compile_gold_for_spans(): + nlp = English() + spans_key = "sc" + + pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")] + ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")] + eg = Example(pred, ref) + + data = _compile_gold([eg], ["spancat"], nlp, True) + + assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1}) + assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]} + assert data["spans_per_type"][spans_key] == { + "ORG": [Span(ref, 3, 6, "ORG")], + "GPE": [Span(ref, 5, 6, "GPE")], + } + assert data["sb_per_type"][spans_key] == { + "ORG": {"start": [ref[2:3]], "end": [ref[6:7]]}, + "GPE": {"start": [ref[4:5]], "end": [ref[6:7]]}, + } + + +def test_frequency_distribution_is_correct(): + nlp = English() + docs = [ + Doc(nlp.vocab, words=["Bank", "of", "China"]), + Doc(nlp.vocab, words=["China"]), + ] + + expected = Counter({"china": 0.5, "bank": 0.25, "of": 0.25}) + freq_distribution = _get_distribution(docs, normalize=True) + assert freq_distribution == expected + + +def test_kl_divergence_computation_is_correct(): + p = Counter({"a": 0.5, "b": 0.25}) + q = Counter({"a": 0.25, "b": 0.50, "c": 0.15, "d": 0.10}) + result = _get_kl_divergence(p, q) + expected = 0.1733 + assert math.isclose(result, expected, rel_tol=1e-3) + + +def test_get_span_characteristics_return_value(): + nlp = English() + spans_key = "sc" + + pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")] + ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")] + eg = Example(pred, ref) + + examples = [eg] + data = _compile_gold(examples, ["spancat"], nlp, True) + span_characteristics = _get_span_characteristics( + examples=examples, compiled_gold=data, spans_key=spans_key + ) + + assert {"sd", "bd", "lengths"}.issubset(span_characteristics.keys()) + assert span_characteristics["min_length"] == 1 + assert span_characteristics["max_length"] == 3 + + +def test_ensure_print_span_characteristics_wont_fail(): + """Test if interface between two methods aren't destroyed if refactored""" + nlp = English() + spans_key = "sc" + + pred = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + pred.spans[spans_key] = [Span(pred, 3, 6, "ORG"), Span(pred, 5, 6, "GPE")] + ref = Doc(nlp.vocab, words=["Welcome", "to", "the", "Bank", "of", "China", "."]) + ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")] + eg = Example(pred, ref) + + examples = [eg] + data = _compile_gold(examples, ["spancat"], nlp, True) + span_characteristics = _get_span_characteristics( + examples=examples, compiled_gold=data, spans_key=spans_key + ) + _print_span_characteristics(span_characteristics) + + +@pytest.mark.parametrize("threshold", [70, 80, 85, 90, 95]) +def test_span_length_freq_dist_threshold_must_be_correct(threshold): + sample_span_lengths = { + "span_type_1": [1, 4, 4, 5], + "span_type_2": [5, 3, 3, 2], + "span_type_3": [3, 1, 3, 3], + } + span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold) + assert sum(span_freqs.values()) >= threshold + + +def test_span_length_freq_dist_output_must_be_correct(): + sample_span_lengths = { + "span_type_1": [1, 4, 4, 5], + "span_type_2": [5, 3, 3, 2], + "span_type_3": [3, 1, 3, 3], + } + threshold = 90 + span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold) + assert sum(span_freqs.values()) >= threshold + assert list(span_freqs.keys()) == [3, 1, 4, 5, 2] diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md index e801ff0a6..dd396d0b3 100644 --- a/website/docs/api/cli.md +++ b/website/docs/api/cli.md @@ -466,6 +466,18 @@ takes the same arguments as `train` and reads settings off the + + +If your pipeline contains a `spancat` component, then this command will also +report span characteristics such as the average span length and the span (or +span boundary) distinctiveness. The distinctiveness measure shows how different +the tokens are with respect to the rest of the corpus using the KL-divergence of +the token distributions. To learn more, you can check out Papay et al.'s work on +[*Dissecting Span Identification Tasks with Performance Prediction* (EMNLP +2020)](https://aclanthology.org/2020.emnlp-main.396/). + + + ```cli $ python -m spacy debug data [config_path] [--code] [--ignore-warnings] [--verbose] [--no-format] [overrides] ```