From 913d74f5092f7a1b286724412c975b590da64fca Mon Sep 17 00:00:00 2001 From: Lj Miranda <12949683+ljvmiranda921@users.noreply.github.com> Date: Thu, 9 Mar 2023 17:30:59 +0800 Subject: [PATCH] Add spancat_singlelabel pipeline for multiclass and non-overlapping span labelling tasks (#11365) * [wip] Update * [wip] Update * Add initial port * [wip] Update * Fix all imports * Add spancat_exclusive to pipeline * [WIP] Update * [ci skip] Add breakpoint for debugging * Use spacy.SpanCategorizer.v1 as default archi * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: kadarakos * [ci skip] Small updates * Use Softmax v2 directly from thinc * Cache the label map * Fix mypy errors However, I ignored line 370 because it opened up a bunch of type errors that might be trickier to solve and might lead to a more complicated codebase. * avoid multiplication with 1.0 Co-authored-by: kadarakos * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: Sofie Van Landeghem * Update component versions to v2 * Add scorer to docstring * Add _n_labels property to SpanCategorizer Instead of using len(self.labels) in initialize() I am using a private property self._n_labels. This achieves implementation parity and allows me to delete the whole initialize() method for spancat_exclusive (since it's now the same with spancat). * Inherit from SpanCat instead of TrainablePipe This commit changes the inheritance structure of Exclusive_Spancat, now it's inheriting from SpanCategorizer than TrainablePipe. This allows me to remove duplicate methods that are already present in the parent function. * Revert documentation link to spancat * Fix init call for exclusive spancat * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: Adriane Boyd * Import Suggester from spancat * Include zero_init.v1 for spancat * Implement _allow_extra_label to use _n_labels To ensure that spancat / spancat_exclusive cannot be resized after initialization, I inherited the _allow_extra_label() method from spacy/pipeline/trainable_pipe.pyx and used self._n_labels instead of len(self.labels) for checking. I think that changing it locally is a better solution rather than forcing each class that inherits TrainablePipe to use the self._n_labels attribute. Also note that I turned-off black formatting in this block of code because it reads better without the overhang. * Extend existing tests to spancat_exclusive In this commit, I extended the existing tests for spancat to include spancat_exclusive. I parametrized the test functions with 'name' (similar var name with textcat and textcat_multilabel) for each applicable test. TODO: Add overfitting tests for spancat_exclusive * Update documentation for spancat * Turn on formatting for allow_extra_label * Remove initializers in default config * Use DEFAULT_EXCL_SPANCAT_MODEL I also renamed spancat_exclusive_default_config into spancat_excl_default_config because black does some not pretty formatting changes. * Update documentation Update grammar and usage Co-authored-by: Adriane Boyd * Clarify docstring for Exclusive_SpanCategorizer * Remove mypy ignore and typecast labels to list * Fix documentation API * Use a single variable for tests * Update defaults for number of rows Co-authored-by: Adriane Boyd * Put back initializers in spancat config Whenever I remove model.scorer.init_w and model.scorer.init_b, I encounter an error in the test: SystemError: returned a result with an error set. My Thinc version is 8.1.5, but I can't seem to check what's causing the error. * Update spancat_exclusive docstring * Remove init_W and init_B parameters This commit is expected to fail until the new Thinc release. * Require thinc>=8.1.6 for serializable Softmax defaults * Handle zero suggestions to make tests pass I'm not sure if this is the most elegant solution. But what should happen is that the _make_span_group function MUST return an empty SpanGroup if there are no suggestions. The error happens when the 'scores' variable is empty. We cannot get the 'predicted' and other downstream vars. * Better approach for handling zero suggestions * Update website/docs/api/spancategorizer.md Co-authored-by: Adriane Boyd * Update spancategorizer headers * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem * Add default value in negative_weight in docs * Add default value in allow_overlap in docs * Update how spancat_exclusive is constructed In this commit, I added the following: - Put the default values of negative_weight and allow_overlap in the default_config dictionary. - Rename make_spancat -> make_exclusive_spancat * Run prettier on spancategorizer.mdx * Change exactly one -> at most one * Add suggester documentation in Exclusive_SpanCategorizer * Add suggester to spancat docstrings * merge multilabel and singlelabel spancat * rename spancat_exclusive to singlelable * wire up different make_spangroups for single and multilabel * black * black * add docstrings * more docstring and fix negative_label * don't rely on default arguments * black * remove spancat exclusive * replace single_label with add_negative_label and adjust inference * mypy * logical bug in configuration check * add spans.attrs[scores] * single label make_spangroup test * bugfix * black * tests for make_span_group with negative labels * refactor make_span_group * black * Update spacy/tests/pipeline/test_spancat.py Co-authored-by: Adriane Boyd * remove duplicate declaration * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * raise error instead of just print * make label mapper private * update docs * run prettier * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * don't keep recomputing self._label_map for each span * typo in docs * Intervals to private and document 'name' param * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd * add Tag to new features * replace tags * revert * revert * revert * revert * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd * prettier * Fix merge * Update website/docs/api/spancategorizer.mdx * remove references to 'single_label' * remove old paragraph * Add spancat_singlelabel to config template * Format * Extend init config tests --------- Co-authored-by: kadarakos Co-authored-by: Sofie Van Landeghem Co-authored-by: Adriane Boyd --- spacy/cli/templates/quickstart_training.jinja | 61 +++- spacy/errors.py | 1 + spacy/pipeline/spancat.py | 323 ++++++++++++++++-- spacy/tests/pipeline/test_spancat.py | 157 ++++++++- spacy/tests/test_cli.py | 9 +- website/docs/api/spancategorizer.mdx | 68 ++-- 6 files changed, 552 insertions(+), 67 deletions(-) diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 441189341..c5e8c6c43 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and can help generate the best possible configuration, given a user's requirements. #} {%- set use_transformer = hardware != "cpu" and transformer_data -%} {%- set transformer = transformer_data[optimize] if use_transformer else {} -%} -{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "trainable_lemmatizer"] -%} +{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%} [paths] train = null dev = null @@ -28,7 +28,7 @@ lang = "{{ lang }}" tok2vec/transformer. #} {%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%} {%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%} -{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%} +{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%} {%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%} {%- else -%} {%- set full_pipeline = components -%} @@ -159,6 +159,36 @@ grad_factor = 1.0 sizes = [1,2,3] {% endif -%} +{% if "spancat_singlelabel" in components %} +[components.spancat_singlelabel] +factory = "spancat_singlelabel" +negative_weight = 1.0 +allow_overlap = true +scorer = {"@scorers":"spacy.spancat_scorer.v1"} +spans_key = "sc" + +[components.spancat_singlelabel.model] +@architectures = "spacy.SpanCategorizer.v1" + +[components.spancat_singlelabel.model.reducer] +@layers = "spacy.mean_max_reducer.v1" +hidden_size = 128 + +[components.spancat_singlelabel.model.scorer] +@layers = "Softmax.v2" + +[components.spancat_singlelabel.model.tok2vec] +@architectures = "spacy-transformers.TransformerListener.v1" +grad_factor = 1.0 + +[components.spancat_singlelabel.model.tok2vec.pooling] +@layers = "reduce_mean.v1" + +[components.spancat_singlelabel.suggester] +@misc = "spacy.ngram_suggester.v1" +sizes = [1,2,3] +{% endif %} + {% if "trainable_lemmatizer" in components -%} [components.trainable_lemmatizer] factory = "trainable_lemmatizer" @@ -389,6 +419,33 @@ width = ${components.tok2vec.model.encode.width} sizes = [1,2,3] {% endif %} +{% if "spancat_singlelabel" in components %} +[components.spancat_singlelabel] +factory = "spancat_singlelabel" +negative_weight = 1.0 +allow_overlap = true +scorer = {"@scorers":"spacy.spancat_scorer.v1"} +spans_key = "sc" + +[components.spancat_singlelabel.model] +@architectures = "spacy.SpanCategorizer.v1" + +[components.spancat_singlelabel.model.reducer] +@layers = "spacy.mean_max_reducer.v1" +hidden_size = 128 + +[components.spancat_singlelabel.model.scorer] +@layers = "Softmax.v2" + +[components.spancat_singlelabel.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.encode.width} + +[components.spancat_singlelabel.suggester] +@misc = "spacy.ngram_suggester.v1" +sizes = [1,2,3] +{% endif %} + {% if "trainable_lemmatizer" in components -%} [components.trainable_lemmatizer] factory = "trainable_lemmatizer" diff --git a/spacy/errors.py b/spacy/errors.py index 1047ed21a..c897c29ff 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -969,6 +969,7 @@ class Errors(metaclass=ErrorsWithCodes): "with `displacy.serve(doc, port=port)`") E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` " "or use `auto_select_port=True` to pick an available port automatically.") + E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index a3388e81a..983e1fba9 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,4 +1,5 @@ -from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any +from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union +from dataclasses import dataclass from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer from thinc.types import Ragged, Ints2d, Floats2d @@ -43,7 +44,36 @@ maxout_pieces = 3 depth = 4 """ +spancat_singlelabel_default_config = """ +[model] +@architectures = "spacy.SpanCategorizer.v1" +scorer = {"@layers": "Softmax.v2"} + +[model.reducer] +@layers = spacy.mean_max_reducer.v1 +hidden_size = 128 + +[model.tok2vec] +@architectures = "spacy.Tok2Vec.v2" +[model.tok2vec.embed] +@architectures = "spacy.MultiHashEmbed.v1" +width = 96 +rows = [5000, 1000, 2500, 1000] +attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] +include_static_vectors = false + +[model.tok2vec.encode] +@architectures = "spacy.MaxoutWindowEncoder.v2" +width = ${model.tok2vec.embed.width} +window_size = 1 +maxout_pieces = 3 +depth = 4 +""" + DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"] +DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str( + spancat_singlelabel_default_config +)["model"] @runtime_checkable @@ -119,10 +149,14 @@ def make_spancat( threshold: float, max_positive: Optional[int], ) -> "SpanCategorizer": - """Create a SpanCategorizer component. The span categorizer consists of two + """Create a SpanCategorizer component and configure it for multi-label + classification to be able to assign multiple labels for each span. + The span categorizer consists of two parts: a suggester function that proposes candidate spans, and a labeller model that predicts one or more labels for each span. + name (str): The component instance name, used to add entries to the + losses during training. suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. Spans are returned as a ragged array with two integer columns, for the start and end positions. @@ -144,12 +178,80 @@ def make_spancat( """ return SpanCategorizer( nlp.vocab, - suggester=suggester, model=model, - spans_key=spans_key, - threshold=threshold, - max_positive=max_positive, + suggester=suggester, name=name, + spans_key=spans_key, + negative_weight=None, + allow_overlap=True, + max_positive=max_positive, + threshold=threshold, + scorer=scorer, + add_negative_label=False, + ) + + +@Language.factory( + "spancat_singlelabel", + assigns=["doc.spans"], + default_config={ + "spans_key": "sc", + "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL, + "negative_weight": 1.0, + "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, + "scorer": {"@scorers": "spacy.spancat_scorer.v1"}, + "allow_overlap": True, + }, + default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0}, +) +def make_spancat_singlelabel( + nlp: Language, + name: str, + suggester: Suggester, + model: Model[Tuple[List[Doc], Ragged], Floats2d], + spans_key: str, + negative_weight: float, + allow_overlap: bool, + scorer: Optional[Callable], +) -> "SpanCategorizer": + """Create a SpanCategorizer component and configure it for multi-class + classification. With this configuration each span can get at most one + label. The span categorizer consists of two + parts: a suggester function that proposes candidate spans, and a labeller + model that predicts one or more labels for each span. + + name (str): The component instance name, used to add entries to the + losses during training. + suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. + Spans are returned as a ragged array with two integer columns, for the + start and end positions. + model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that + is given a list of documents and (start, end) indices representing + candidate span offsets. The model predicts a probability for each category + for each span. + spans_key (str): Key of the doc.spans dict to save the spans under. During + initialization and training, the component will look for spans on the + reference document under the same key. + scorer (Optional[Callable]): The scoring method. Defaults to + Scorer.score_spans for the Doc.spans[spans_key] with overlapping + spans allowed. + negative_weight (float): Multiplier for the loss terms. + Can be used to downweight the negative samples if there are too many. + allow_overlap (bool): If True the data is assumed to contain overlapping spans. + Otherwise it produces non-overlapping spans greedily prioritizing + higher assigned label scores. + """ + return SpanCategorizer( + nlp.vocab, + model=model, + suggester=suggester, + name=name, + spans_key=spans_key, + negative_weight=negative_weight, + allow_overlap=allow_overlap, + max_positive=1, + add_negative_label=True, + threshold=None, scorer=scorer, ) @@ -172,6 +274,27 @@ def make_spancat_scorer(): return spancat_score +@dataclass +class _Intervals: + """ + Helper class to avoid storing overlapping spans. + """ + + def __init__(self): + self.ranges = set() + + def add(self, i, j): + for e in range(i, j): + self.ranges.add(e) + + def __contains__(self, rang): + i, j = rang + for e in range(i, j): + if e in self.ranges: + return True + return False + + class SpanCategorizer(TrainablePipe): """Pipeline component to label spans of text. @@ -185,25 +308,43 @@ class SpanCategorizer(TrainablePipe): suggester: Suggester, name: str = "spancat", *, + add_negative_label: bool = False, spans_key: str = "spans", - threshold: float = 0.5, + negative_weight: Optional[float] = 1.0, + allow_overlap: Optional[bool] = True, max_positive: Optional[int] = None, + threshold: Optional[float] = 0.5, scorer: Optional[Callable] = spancat_score, ) -> None: - """Initialize the span categorizer. + """Initialize the multi-label or multi-class span categorizer. + vocab (Vocab): The shared vocabulary. model (thinc.api.Model): The Thinc Model powering the pipeline component. + For multi-class classification (single label per span) we recommend + using a Softmax classifier as a the final layer, while for multi-label + classification (multiple possible labels per span) we recommend Logistic. + suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans. + Spans are returned as a ragged array with two integer columns, for the + start and end positions. name (str): The component instance name, used to add entries to the losses during training. spans_key (str): Key of the Doc.spans dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. - threshold (float): Minimum probability to consider a prediction - positive. Spans with a positive prediction will be saved on the Doc. - Defaults to 0.5. + add_negative_label (bool): Learn to predict a special 'negative_label' + when a Span is not annotated. + threshold (Optional[float]): Minimum probability to consider a prediction + positive. Defaults to 0.5. Spans with a positive prediction will be saved + on the Doc. max_positive (Optional[int]): Maximum number of labels to consider positive per span. Defaults to None, indicating no limit. + negative_weight (float): Multiplier for the loss terms. + Can be used to downweight the negative samples if there are too many + when add_negative_label is True. Otherwise its unused. + allow_overlap (bool): If True the data is assumed to contain overlapping spans. + Otherwise it produces non-overlapping spans greedily prioritizing + higher assigned label scores. Only used when max_positive is 1. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_spans for the Doc.spans[spans_key] with overlapping spans allowed. @@ -215,12 +356,17 @@ class SpanCategorizer(TrainablePipe): "spans_key": spans_key, "threshold": threshold, "max_positive": max_positive, + "negative_weight": negative_weight, + "allow_overlap": allow_overlap, } self.vocab = vocab self.suggester = suggester self.model = model self.name = name self.scorer = scorer + self.add_negative_label = add_negative_label + if not allow_overlap and max_positive is not None and max_positive > 1: + raise ValueError(Errors.E1051.format(max_positive=max_positive)) @property def key(self) -> str: @@ -230,6 +376,21 @@ class SpanCategorizer(TrainablePipe): """ return str(self.cfg["spans_key"]) + def _allow_extra_label(self) -> None: + """Raise an error if the component can not add any more labels.""" + nO = None + if self.model.has_dim("nO"): + nO = self.model.get_dim("nO") + elif self.model.has_ref("output_layer") and self.model.get_ref( + "output_layer" + ).has_dim("nO"): + nO = self.model.get_ref("output_layer").get_dim("nO") + if nO is not None and nO == self._n_labels: + if not self.is_resizable: + raise ValueError( + Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")) + ) + def add_label(self, label: str) -> int: """Add a new label to the pipe. @@ -263,6 +424,27 @@ class SpanCategorizer(TrainablePipe): """ return list(self.labels) + @property + def _label_map(self) -> Dict[str, int]: + """RETURNS (Dict[str, int]): The label map.""" + return {label: i for i, label in enumerate(self.labels)} + + @property + def _n_labels(self) -> int: + """RETURNS (int): Number of labels.""" + if self.add_negative_label: + return len(self.labels) + 1 + else: + return len(self.labels) + + @property + def _negative_label_i(self) -> Union[int, None]: + """RETURNS (Union[int, None]): Index of the negative label.""" + if self.add_negative_label: + return len(self.label_data) + else: + return None + def predict(self, docs: Iterable[Doc]): """Apply the pipeline's model to a batch of docs, without modifying them. @@ -304,14 +486,24 @@ class SpanCategorizer(TrainablePipe): DOCS: https://spacy.io/api/spancategorizer#set_annotations """ - labels = self.labels indices, scores = indices_scores offset = 0 for i, doc in enumerate(docs): indices_i = indices[i].dataXd - doc.spans[self.key] = self._make_span_group( - doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type] - ) + allow_overlap = cast(bool, self.cfg["allow_overlap"]) + if self.cfg["max_positive"] == 1: + doc.spans[self.key] = self._make_span_group_singlelabel( + doc, + indices_i, + scores[offset : offset + indices.lengths[i]], + allow_overlap, + ) + else: + doc.spans[self.key] = self._make_span_group_multilabel( + doc, + indices_i, + scores[offset : offset + indices.lengths[i]], + ) offset += indices.lengths[i] def update( @@ -371,9 +563,11 @@ class SpanCategorizer(TrainablePipe): spans = Ragged( self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths) ) - label_map = {label: i for i, label in enumerate(self.labels)} target = numpy.zeros(scores.shape, dtype=scores.dtype) + if self.add_negative_label: + negative_spans = numpy.ones((scores.shape[0])) offset = 0 + label_map = self._label_map for i, eg in enumerate(examples): # Map (start, end) offset of spans to the row in the d_scores array, # so that we can adjust the gradient for predictions that were @@ -390,10 +584,16 @@ class SpanCategorizer(TrainablePipe): row = spans_index[key] k = label_map[gold_span.label_] target[row, k] = 1.0 + if self.add_negative_label: + # delete negative label target. + negative_spans[row] = 0.0 # The target is a flat array for all docs. Track the position # we're at within the flat array. offset += spans.lengths[i] target = self.model.ops.asarray(target, dtype="f") # type: ignore + if self.add_negative_label: + negative_samples = numpy.nonzero(negative_spans)[0] + target[negative_samples, self._negative_label_i] = 1.0 # type: ignore # The target will have the values 0 (for untrue predictions) or 1 # (for true predictions). # The scores should be in the range [0, 1]. @@ -402,6 +602,10 @@ class SpanCategorizer(TrainablePipe): # If the prediction is 0.9 and it's false, the gradient will be # 0.9 (0.9 - 0.0) d_scores = scores - target + if self.add_negative_label: + neg_weight = cast(float, self.cfg["negative_weight"]) + if neg_weight != 1.0: + d_scores[negative_samples] *= neg_weight loss = float((d_scores**2).sum()) return loss, d_scores @@ -438,7 +642,7 @@ class SpanCategorizer(TrainablePipe): if subbatch: docs = [eg.x for eg in subbatch] spans = build_ngram_suggester(sizes=[1])(docs) - Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels)) + Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels) self.model.initialize(X=(docs, spans), Y=Y) else: self.model.initialize() @@ -452,31 +656,96 @@ class SpanCategorizer(TrainablePipe): eg.reference.spans.get(self.key, []), allow_overlap=True ) - def _make_span_group( - self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str] + def _make_span_group_multilabel( + self, + doc: Doc, + indices: Ints2d, + scores: Floats2d, ) -> SpanGroup: + """Find the top-k labels for each span (k=max_positive).""" spans = SpanGroup(doc, name=self.key) - max_positive = self.cfg["max_positive"] + if scores.size == 0: + return spans + scores = self.model.ops.to_numpy(scores) + indices = self.model.ops.to_numpy(indices) threshold = self.cfg["threshold"] + max_positive = self.cfg["max_positive"] keeps = scores >= threshold - ranked = (scores * -1).argsort() # type: ignore if max_positive is not None: assert isinstance(max_positive, int) + if self.add_negative_label: + negative_scores = numpy.copy(scores[:, self._negative_label_i]) + scores[:, self._negative_label_i] = -numpy.inf + ranked = (scores * -1).argsort() # type: ignore + scores[:, self._negative_label_i] = negative_scores + else: + ranked = (scores * -1).argsort() # type: ignore span_filter = ranked[:, max_positive:] for i, row in enumerate(span_filter): keeps[i, row] = False - spans.attrs["scores"] = scores[keeps].flatten() - - indices = self.model.ops.to_numpy(indices) - keeps = self.model.ops.to_numpy(keeps) + attrs_scores = [] for i in range(indices.shape[0]): start = indices[i, 0] end = indices[i, 1] - for j, keep in enumerate(keeps[i]): if keep: - spans.append(Span(doc, start, end, label=labels[j])) + if j != self._negative_label_i: + spans.append(Span(doc, start, end, label=self.labels[j])) + attrs_scores.append(scores[i, j]) + spans.attrs["scores"] = numpy.array(attrs_scores) + return spans + + def _make_span_group_singlelabel( + self, + doc: Doc, + indices: Ints2d, + scores: Floats2d, + allow_overlap: bool = True, + ) -> SpanGroup: + """Find the argmax label for each span.""" + # Handle cases when there are zero suggestions + if scores.size == 0: + return SpanGroup(doc, name=self.key) + scores = self.model.ops.to_numpy(scores) + indices = self.model.ops.to_numpy(indices) + predicted = scores.argmax(axis=1) + argmax_scores = numpy.take_along_axis( + scores, numpy.expand_dims(predicted, 1), axis=1 + ) + keeps = numpy.ones(predicted.shape, dtype=bool) + # Remove samples where the negative label is the argmax. + if self.add_negative_label: + keeps = numpy.logical_and(keeps, predicted != self._negative_label_i) + # Filter samples according to threshold. + threshold = self.cfg["threshold"] + if threshold is not None: + keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze()) + # Sort spans according to argmax probability + if not allow_overlap: + # Get the probabilities + sort_idx = (argmax_scores.squeeze() * -1).argsort() + predicted = predicted[sort_idx] + indices = indices[sort_idx] + keeps = keeps[sort_idx] + seen = _Intervals() + spans = SpanGroup(doc, name=self.key) + attrs_scores = [] + for i in range(indices.shape[0]): + if not keeps[i]: + continue + + label = predicted[i] + start = indices[i, 0] + end = indices[i, 1] + + if not allow_overlap: + if (start, end) in seen: + continue + else: + seen.add(start, end) + attrs_scores.append(argmax_scores[i]) + spans.append(Span(doc, start, end, label=self.labels[label])) return spans diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index e9db983d3..cf6304042 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -15,6 +15,8 @@ OPS = get_current_ops() SPAN_KEY = "labeled_spans" +SPANCAT_COMPONENTS = ["spancat", "spancat_singlelabel"] + TRAIN_DATA = [ ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}), ( @@ -41,38 +43,42 @@ def make_examples(nlp, data=TRAIN_DATA): return train_examples -def test_no_label(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_no_label(name): nlp = Language() - nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) with pytest.raises(ValueError): nlp.initialize() -def test_no_resize(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_no_resize(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) spancat.add_label("Thing") spancat.add_label("Phrase") assert spancat.labels == ("Thing", "Phrase") nlp.initialize() - assert spancat.model.get_dim("nO") == 2 + assert spancat.model.get_dim("nO") == spancat._n_labels # this throws an error because the spancat can't be resized after initialization with pytest.raises(ValueError): spancat.add_label("Stuff") -def test_implicit_labels(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_implicit_labels(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) assert len(spancat.labels) == 0 train_examples = make_examples(nlp) nlp.initialize(get_examples=lambda: train_examples) assert spancat.labels == ("PERSON", "LOC") -def test_explicit_labels(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_explicit_labels(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) assert len(spancat.labels) == 0 spancat.add_label("PERSON") spancat.add_label("LOC") @@ -102,13 +108,13 @@ def test_doc_gc(): # XXX This fails with length 0 sometimes assert len(spangroup) > 0 with pytest.raises(RuntimeError): - span = spangroup[0] + spangroup[0] @pytest.mark.parametrize( "max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)] ) -def test_make_spangroup(max_positive, nr_results): +def test_make_spangroup_multilabel(max_positive, nr_results): fix_random_seed(0) nlp = Language() spancat = nlp.add_pipe( @@ -120,10 +126,12 @@ def test_make_spangroup(max_positive, nr_results): indices = ngram_suggester([doc])[0].dataXd assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat.add_label(label) scores = numpy.asarray( [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" ) - spangroup = spancat._make_span_group(doc, indices, scores, labels) + spangroup = spancat._make_span_group_multilabel(doc, indices, scores) assert len(spangroup) == nr_results # first span is always the second token "London" @@ -154,6 +162,118 @@ def test_make_spangroup(max_positive, nr_results): assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) +@pytest.mark.parametrize( + "threshold,allow_overlap,nr_results", + [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)], +) +def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results): + fix_random_seed(0) + nlp = Language() + spancat = nlp.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": threshold, + "max_positive": 1, + }, + ) + doc = nlp.make_doc("Greater London") + ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2]) + indices = ngram_suggester([doc])[0].dataXd + assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) + labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat.add_label(label) + scores = numpy.asarray( + [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" + ) + spangroup = spancat._make_span_group_singlelabel( + doc, indices, scores, allow_overlap + ) + assert len(spangroup) == nr_results + if threshold > 0.4: + if allow_overlap: + assert spangroup[0].text == "London" + assert spangroup[0].label_ == "City" + assert spangroup[1].text == "Greater London" + assert spangroup[1].label_ == "GreatCity" + + else: + assert spangroup[0].text == "Greater London" + assert spangroup[0].label_ == "GreatCity" + else: + if allow_overlap: + assert spangroup[0].text == "Greater" + assert spangroup[0].label_ == "City" + assert spangroup[1].text == "London" + assert spangroup[1].label_ == "City" + assert spangroup[2].text == "Greater London" + assert spangroup[2].label_ == "GreatCity" + else: + assert spangroup[0].text == "Greater London" + + +def test_make_spangroup_negative_label(): + fix_random_seed(0) + nlp_single = Language() + nlp_multi = Language() + spancat_single = nlp_single.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": 0.1, + "max_positive": 1, + }, + ) + spancat_multi = nlp_multi.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": 0.1, + "max_positive": 2, + }, + ) + spancat_single.add_negative_label = True + spancat_multi.add_negative_label = True + doc = nlp_single.make_doc("Greater London") + labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat_multi.add_label(label) + spancat_single.add_label(label) + ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2]) + indices = ngram_suggester([doc])[0].dataXd + assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) + scores = numpy.asarray( + [ + [0.2, 0.4, 0.3, 0.1, 0.1], + [0.1, 0.6, 0.2, 0.4, 0.9], + [0.8, 0.7, 0.3, 0.9, 0.1], + ], + dtype="f", + ) + spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores) + spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores) + assert len(spangroup_single) == 2 + assert spangroup_single[0].text == "Greater" + assert spangroup_single[0].label_ == "City" + assert spangroup_single[1].text == "Greater London" + assert spangroup_single[1].label_ == "GreatCity" + + assert len(spangroup_multi) == 6 + assert spangroup_multi[0].text == "Greater" + assert spangroup_multi[0].label_ == "City" + assert spangroup_multi[1].text == "Greater" + assert spangroup_multi[1].label_ == "Person" + assert spangroup_multi[2].text == "London" + assert spangroup_multi[2].label_ == "City" + assert spangroup_multi[3].text == "London" + assert spangroup_multi[3].label_ == "GreatCity" + assert spangroup_multi[4].text == "Greater London" + assert spangroup_multi[4].label_ == "Thing" + assert spangroup_multi[5].text == "Greater London" + assert spangroup_multi[5].label_ == "GreatCity" + + def test_ngram_suggester(en_tokenizer): # test different n-gram lengths for size in [1, 2, 3]: @@ -371,9 +491,9 @@ def test_overfitting_IO_overlapping(): assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} -def test_zero_suggestions(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_zero_suggestions(name): # Test with a suggester that can return 0 suggestions - @registry.misc("test_mixed_zero_suggester") def make_mixed_zero_suggester(): def mixed_zero_suggester(docs, *, ops=None): @@ -400,7 +520,7 @@ def test_zero_suggestions(): fix_random_seed(0) nlp = English() spancat = nlp.add_pipe( - "spancat", + name, config={ "suggester": {"@misc": "test_mixed_zero_suggester"}, "spans_key": SPAN_KEY, @@ -408,7 +528,7 @@ def test_zero_suggestions(): ) train_examples = make_examples(nlp) optimizer = nlp.initialize(get_examples=lambda: train_examples) - assert spancat.model.get_dim("nO") == 2 + assert spancat.model.get_dim("nO") == spancat._n_labels assert set(spancat.labels) == {"LOC", "PERSON"} nlp.update(train_examples, sgd=optimizer) @@ -424,9 +544,10 @@ def test_zero_suggestions(): list(nlp.pipe(["", "one", "three three three"])) -def test_set_candidates(): +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) +def test_set_candidates(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) train_examples = make_examples(nlp) nlp.initialize(get_examples=lambda: train_examples) texts = [ diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index f5bcdfd23..1fdf059b3 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -552,7 +552,14 @@ def test_parse_cli_overrides(): @pytest.mark.parametrize("lang", ["en", "nl"]) @pytest.mark.parametrize( - "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]] + "pipeline", + [ + ["tagger", "parser", "ner"], + [], + ["ner", "textcat", "sentencizer"], + ["morphologizer", "spancat", "entity_linker"], + ["spancat_singlelabel", "textcat_multilabel"], + ], ) @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) @pytest.mark.parametrize("pretraining", [True, False]) diff --git a/website/docs/api/spancategorizer.mdx b/website/docs/api/spancategorizer.mdx index f39c0aff9..c7de2324b 100644 --- a/website/docs/api/spancategorizer.mdx +++ b/website/docs/api/spancategorizer.mdx @@ -13,6 +13,13 @@ A span categorizer consists of two parts: a [suggester function](#suggesters) that proposes candidate spans, which may or may not overlap, and a labeler model that predicts zero or more labels for each candidate. +This component comes in two forms: `spancat` and `spancat_singlelabel` (added in +spaCy v3.5.1). When you need to perform multi-label classification on your +spans, use `spancat`. The `spancat` component uses a `Logistic` layer where the +output class probabilities are independent for each class. However, if you need +to predict at most one true class for a span, then use `spancat_singlelabel`. It +uses a `Softmax` layer and treats the task as a multi-class problem. + Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc. Individual span scores can be found in `spangroup.attrs["scores"]`. @@ -38,7 +45,7 @@ how the component should be configured. You can override its settings via the [model architectures](/api/architectures) documentation for details on the architectures and their arguments and hyperparameters. -> #### Example +> #### Example (spancat) > > ```python > from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL @@ -52,14 +59,33 @@ architectures and their arguments and hyperparameters. > nlp.add_pipe("spancat", config=config) > ``` -| Setting | Description | -| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ | -| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ | -| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ | -| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ | -| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ | -| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ | +> #### Example (spancat_singlelabel) +> +> ```python +> from spacy.pipeline.spancat import DEFAULT_SPANCAT_SINGLELABEL_MODEL +> config = { +> "threshold": 0.5, +> "spans_key": "labeled_spans", +> "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL, +> "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]}, +> # Additional spancat_singlelabel parameters +> "negative_weight": 0.8, +> "allow_overlap": True, +> } +> nlp.add_pipe("spancat_singlelabel", config=config) +> ``` + +| Setting | Description | +| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ | +| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ | +| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ | +| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Meant to be used in combination with the multi-class `spancat` component with a `Logistic` scoring layer. Defaults to `0.5`. ~~float~~ | +| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. Meant to be used together with the `spancat` component and defaults to 0 with `spancat_singlelabel`. ~~Optional[int]~~ | +| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ | +| `add_negative_label` 3.5.1 | Whether to learn to predict a special negative label for each unannotated `Span` . This should be `True` when using a `Softmax` classifier layer and so its `True` by default for `spancat_singlelabel`. Spans with negative labels and their scores are not stored as annotations. ~~bool~~ | +| `negative_weight` 3.5.1 | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many. It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ | +| `allow_overlap` 3.5.1 | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ | ```python %%GITHUB_SPACY/spacy/pipeline/spancat.py @@ -71,6 +97,7 @@ architectures and their arguments and hyperparameters. > > ```python > # Construction via add_pipe with default model +> # Replace 'spancat' with 'spancat_singlelabel' for exclusive classes > spancat = nlp.add_pipe("spancat") > > # Construction via add_pipe with custom model @@ -86,16 +113,19 @@ Create a new pipeline instance. In your application, you would normally use a shortcut for this and instantiate the component using its string name and [`nlp.add_pipe`](/api/language#create_pipe). -| Name | Description | -| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `vocab` | The shared vocabulary. ~~Vocab~~ | -| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ | -| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ | -| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | -| _keyword-only_ | | -| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ | -| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ | -| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ | +| Name | Description | +| --------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `vocab` | The shared vocabulary. ~~Vocab~~ | +| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ | +| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ | +| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | +| _keyword-only_ | | +| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ | +| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ | +| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ | +| `allow_overlap` 3.5.1 | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ | +| `add_negative_label` 3.5.1 | Whether to learn to predict a special negative label for each unannotated `Span`. This should be `True` when using a `Softmax` classifier layer and so its `True` by default for `spancat_singlelabel` . Spans with negative labels and their scores are not stored as annotations. ~~bool~~ | +| `negative_weight` 3.5.1 | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many . It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ | ## SpanCategorizer.\_\_call\_\_ {id="call",tag="method"}