mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add spancat_singlelabel pipeline for multiclass and non-overlapping span labelling tasks (#11365)
* [wip] Update * [wip] Update * Add initial port * [wip] Update * Fix all imports * Add spancat_exclusive to pipeline * [WIP] Update * [ci skip] Add breakpoint for debugging * Use spacy.SpanCategorizer.v1 as default archi * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: kadarakos <kadar.akos@gmail.com> * [ci skip] Small updates * Use Softmax v2 directly from thinc * Cache the label map * Fix mypy errors However, I ignored line 370 because it opened up a bunch of type errors that might be trickier to solve and might lead to a more complicated codebase. * avoid multiplication with 1.0 Co-authored-by: kadarakos <kadar.akos@gmail.com> * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update component versions to v2 * Add scorer to docstring * Add _n_labels property to SpanCategorizer Instead of using len(self.labels) in initialize() I am using a private property self._n_labels. This achieves implementation parity and allows me to delete the whole initialize() method for spancat_exclusive (since it's now the same with spancat). * Inherit from SpanCat instead of TrainablePipe This commit changes the inheritance structure of Exclusive_Spancat, now it's inheriting from SpanCategorizer than TrainablePipe. This allows me to remove duplicate methods that are already present in the parent function. * Revert documentation link to spancat * Fix init call for exclusive spancat * Update spacy/pipeline/spancat_exclusive.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Import Suggester from spancat * Include zero_init.v1 for spancat * Implement _allow_extra_label to use _n_labels To ensure that spancat / spancat_exclusive cannot be resized after initialization, I inherited the _allow_extra_label() method from spacy/pipeline/trainable_pipe.pyx and used self._n_labels instead of len(self.labels) for checking. I think that changing it locally is a better solution rather than forcing each class that inherits TrainablePipe to use the self._n_labels attribute. Also note that I turned-off black formatting in this block of code because it reads better without the overhang. * Extend existing tests to spancat_exclusive In this commit, I extended the existing tests for spancat to include spancat_exclusive. I parametrized the test functions with 'name' (similar var name with textcat and textcat_multilabel) for each applicable test. TODO: Add overfitting tests for spancat_exclusive * Update documentation for spancat * Turn on formatting for allow_extra_label * Remove initializers in default config * Use DEFAULT_EXCL_SPANCAT_MODEL I also renamed spancat_exclusive_default_config into spancat_excl_default_config because black does some not pretty formatting changes. * Update documentation Update grammar and usage Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Clarify docstring for Exclusive_SpanCategorizer * Remove mypy ignore and typecast labels to list * Fix documentation API * Use a single variable for tests * Update defaults for number of rows Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Put back initializers in spancat config Whenever I remove model.scorer.init_w and model.scorer.init_b, I encounter an error in the test: SystemError: <method '__getitem__' of 'dict' objects> returned a result with an error set. My Thinc version is 8.1.5, but I can't seem to check what's causing the error. * Update spancat_exclusive docstring * Remove init_W and init_B parameters This commit is expected to fail until the new Thinc release. * Require thinc>=8.1.6 for serializable Softmax defaults * Handle zero suggestions to make tests pass I'm not sure if this is the most elegant solution. But what should happen is that the _make_span_group function MUST return an empty SpanGroup if there are no suggestions. The error happens when the 'scores' variable is empty. We cannot get the 'predicted' and other downstream vars. * Better approach for handling zero suggestions * Update website/docs/api/spancategorizer.md Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spancategorizer headers * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Add default value in negative_weight in docs * Add default value in allow_overlap in docs * Update how spancat_exclusive is constructed In this commit, I added the following: - Put the default values of negative_weight and allow_overlap in the default_config dictionary. - Rename make_spancat -> make_exclusive_spancat * Run prettier on spancategorizer.mdx * Change exactly one -> at most one * Add suggester documentation in Exclusive_SpanCategorizer * Add suggester to spancat docstrings * merge multilabel and singlelabel spancat * rename spancat_exclusive to singlelable * wire up different make_spangroups for single and multilabel * black * black * add docstrings * more docstring and fix negative_label * don't rely on default arguments * black * remove spancat exclusive * replace single_label with add_negative_label and adjust inference * mypy * logical bug in configuration check * add spans.attrs[scores] * single label make_spangroup test * bugfix * black * tests for make_span_group with negative labels * refactor make_span_group * black * Update spacy/tests/pipeline/test_spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * remove duplicate declaration * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * raise error instead of just print * make label mapper private * update docs * run prettier * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * don't keep recomputing self._label_map for each span * typo in docs * Intervals to private and document 'name' param * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/pipeline/spancat.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * add Tag to new features * replace tags * revert * revert * revert * revert * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update website/docs/api/spancategorizer.mdx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * prettier * Fix merge * Update website/docs/api/spancategorizer.mdx * remove references to 'single_label' * remove old paragraph * Add spancat_singlelabel to config template * Format * Extend init config tests --------- Co-authored-by: kadarakos <kadar.akos@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
4fdf356b29
commit
913d74f509
|
@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and
|
||||||
can help generate the best possible configuration, given a user's requirements. #}
|
can help generate the best possible configuration, given a user's requirements. #}
|
||||||
{%- set use_transformer = hardware != "cpu" and transformer_data -%}
|
{%- set use_transformer = hardware != "cpu" and transformer_data -%}
|
||||||
{%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
|
{%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
|
||||||
{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "trainable_lemmatizer"] -%}
|
{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
|
||||||
[paths]
|
[paths]
|
||||||
train = null
|
train = null
|
||||||
dev = null
|
dev = null
|
||||||
|
@ -28,7 +28,7 @@ lang = "{{ lang }}"
|
||||||
tok2vec/transformer. #}
|
tok2vec/transformer. #}
|
||||||
{%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
|
{%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
|
||||||
{%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
|
{%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
|
||||||
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
|
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
|
||||||
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
|
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{%- set full_pipeline = components -%}
|
{%- set full_pipeline = components -%}
|
||||||
|
@ -159,6 +159,36 @@ grad_factor = 1.0
|
||||||
sizes = [1,2,3]
|
sizes = [1,2,3]
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
|
|
||||||
|
{% if "spancat_singlelabel" in components %}
|
||||||
|
[components.spancat_singlelabel]
|
||||||
|
factory = "spancat_singlelabel"
|
||||||
|
negative_weight = 1.0
|
||||||
|
allow_overlap = true
|
||||||
|
scorer = {"@scorers":"spacy.spancat_scorer.v1"}
|
||||||
|
spans_key = "sc"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model]
|
||||||
|
@architectures = "spacy.SpanCategorizer.v1"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.reducer]
|
||||||
|
@layers = "spacy.mean_max_reducer.v1"
|
||||||
|
hidden_size = 128
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.scorer]
|
||||||
|
@layers = "Softmax.v2"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.tok2vec]
|
||||||
|
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||||
|
grad_factor = 1.0
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.tok2vec.pooling]
|
||||||
|
@layers = "reduce_mean.v1"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.suggester]
|
||||||
|
@misc = "spacy.ngram_suggester.v1"
|
||||||
|
sizes = [1,2,3]
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if "trainable_lemmatizer" in components -%}
|
{% if "trainable_lemmatizer" in components -%}
|
||||||
[components.trainable_lemmatizer]
|
[components.trainable_lemmatizer]
|
||||||
factory = "trainable_lemmatizer"
|
factory = "trainable_lemmatizer"
|
||||||
|
@ -389,6 +419,33 @@ width = ${components.tok2vec.model.encode.width}
|
||||||
sizes = [1,2,3]
|
sizes = [1,2,3]
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if "spancat_singlelabel" in components %}
|
||||||
|
[components.spancat_singlelabel]
|
||||||
|
factory = "spancat_singlelabel"
|
||||||
|
negative_weight = 1.0
|
||||||
|
allow_overlap = true
|
||||||
|
scorer = {"@scorers":"spacy.spancat_scorer.v1"}
|
||||||
|
spans_key = "sc"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model]
|
||||||
|
@architectures = "spacy.SpanCategorizer.v1"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.reducer]
|
||||||
|
@layers = "spacy.mean_max_reducer.v1"
|
||||||
|
hidden_size = 128
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.scorer]
|
||||||
|
@layers = "Softmax.v2"
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.spancat_singlelabel.suggester]
|
||||||
|
@misc = "spacy.ngram_suggester.v1"
|
||||||
|
sizes = [1,2,3]
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if "trainable_lemmatizer" in components -%}
|
{% if "trainable_lemmatizer" in components -%}
|
||||||
[components.trainable_lemmatizer]
|
[components.trainable_lemmatizer]
|
||||||
factory = "trainable_lemmatizer"
|
factory = "trainable_lemmatizer"
|
||||||
|
|
|
@ -969,6 +969,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"with `displacy.serve(doc, port=port)`")
|
"with `displacy.serve(doc, port=port)`")
|
||||||
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
||||||
"or use `auto_select_port=True` to pick an available port automatically.")
|
"or use `auto_select_port=True` to pick an available port automatically.")
|
||||||
|
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
|
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||||
from thinc.api import Optimizer
|
from thinc.api import Optimizer
|
||||||
from thinc.types import Ragged, Ints2d, Floats2d
|
from thinc.types import Ragged, Ints2d, Floats2d
|
||||||
|
@ -43,7 +44,36 @@ maxout_pieces = 3
|
||||||
depth = 4
|
depth = 4
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
spancat_singlelabel_default_config = """
|
||||||
|
[model]
|
||||||
|
@architectures = "spacy.SpanCategorizer.v1"
|
||||||
|
scorer = {"@layers": "Softmax.v2"}
|
||||||
|
|
||||||
|
[model.reducer]
|
||||||
|
@layers = spacy.mean_max_reducer.v1
|
||||||
|
hidden_size = 128
|
||||||
|
|
||||||
|
[model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
[model.tok2vec.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v1"
|
||||||
|
width = 96
|
||||||
|
rows = [5000, 1000, 2500, 1000]
|
||||||
|
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[model.tok2vec.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = ${model.tok2vec.embed.width}
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
depth = 4
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
|
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
|
||||||
|
DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(
|
||||||
|
spancat_singlelabel_default_config
|
||||||
|
)["model"]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -119,10 +149,14 @@ def make_spancat(
|
||||||
threshold: float,
|
threshold: float,
|
||||||
max_positive: Optional[int],
|
max_positive: Optional[int],
|
||||||
) -> "SpanCategorizer":
|
) -> "SpanCategorizer":
|
||||||
"""Create a SpanCategorizer component. The span categorizer consists of two
|
"""Create a SpanCategorizer component and configure it for multi-label
|
||||||
|
classification to be able to assign multiple labels for each span.
|
||||||
|
The span categorizer consists of two
|
||||||
parts: a suggester function that proposes candidate spans, and a labeller
|
parts: a suggester function that proposes candidate spans, and a labeller
|
||||||
model that predicts one or more labels for each span.
|
model that predicts one or more labels for each span.
|
||||||
|
|
||||||
|
name (str): The component instance name, used to add entries to the
|
||||||
|
losses during training.
|
||||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||||
Spans are returned as a ragged array with two integer columns, for the
|
Spans are returned as a ragged array with two integer columns, for the
|
||||||
start and end positions.
|
start and end positions.
|
||||||
|
@ -144,12 +178,80 @@ def make_spancat(
|
||||||
"""
|
"""
|
||||||
return SpanCategorizer(
|
return SpanCategorizer(
|
||||||
nlp.vocab,
|
nlp.vocab,
|
||||||
suggester=suggester,
|
|
||||||
model=model,
|
model=model,
|
||||||
spans_key=spans_key,
|
suggester=suggester,
|
||||||
threshold=threshold,
|
|
||||||
max_positive=max_positive,
|
|
||||||
name=name,
|
name=name,
|
||||||
|
spans_key=spans_key,
|
||||||
|
negative_weight=None,
|
||||||
|
allow_overlap=True,
|
||||||
|
max_positive=max_positive,
|
||||||
|
threshold=threshold,
|
||||||
|
scorer=scorer,
|
||||||
|
add_negative_label=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@Language.factory(
|
||||||
|
"spancat_singlelabel",
|
||||||
|
assigns=["doc.spans"],
|
||||||
|
default_config={
|
||||||
|
"spans_key": "sc",
|
||||||
|
"model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||||
|
"negative_weight": 1.0,
|
||||||
|
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||||
|
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||||
|
"allow_overlap": True,
|
||||||
|
},
|
||||||
|
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||||
|
)
|
||||||
|
def make_spancat_singlelabel(
|
||||||
|
nlp: Language,
|
||||||
|
name: str,
|
||||||
|
suggester: Suggester,
|
||||||
|
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||||
|
spans_key: str,
|
||||||
|
negative_weight: float,
|
||||||
|
allow_overlap: bool,
|
||||||
|
scorer: Optional[Callable],
|
||||||
|
) -> "SpanCategorizer":
|
||||||
|
"""Create a SpanCategorizer component and configure it for multi-class
|
||||||
|
classification. With this configuration each span can get at most one
|
||||||
|
label. The span categorizer consists of two
|
||||||
|
parts: a suggester function that proposes candidate spans, and a labeller
|
||||||
|
model that predicts one or more labels for each span.
|
||||||
|
|
||||||
|
name (str): The component instance name, used to add entries to the
|
||||||
|
losses during training.
|
||||||
|
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||||
|
Spans are returned as a ragged array with two integer columns, for the
|
||||||
|
start and end positions.
|
||||||
|
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
|
||||||
|
is given a list of documents and (start, end) indices representing
|
||||||
|
candidate span offsets. The model predicts a probability for each category
|
||||||
|
for each span.
|
||||||
|
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||||
|
initialization and training, the component will look for spans on the
|
||||||
|
reference document under the same key.
|
||||||
|
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||||
|
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||||
|
spans allowed.
|
||||||
|
negative_weight (float): Multiplier for the loss terms.
|
||||||
|
Can be used to downweight the negative samples if there are too many.
|
||||||
|
allow_overlap (bool): If True the data is assumed to contain overlapping spans.
|
||||||
|
Otherwise it produces non-overlapping spans greedily prioritizing
|
||||||
|
higher assigned label scores.
|
||||||
|
"""
|
||||||
|
return SpanCategorizer(
|
||||||
|
nlp.vocab,
|
||||||
|
model=model,
|
||||||
|
suggester=suggester,
|
||||||
|
name=name,
|
||||||
|
spans_key=spans_key,
|
||||||
|
negative_weight=negative_weight,
|
||||||
|
allow_overlap=allow_overlap,
|
||||||
|
max_positive=1,
|
||||||
|
add_negative_label=True,
|
||||||
|
threshold=None,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,6 +274,27 @@ def make_spancat_scorer():
|
||||||
return spancat_score
|
return spancat_score
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Intervals:
|
||||||
|
"""
|
||||||
|
Helper class to avoid storing overlapping spans.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.ranges = set()
|
||||||
|
|
||||||
|
def add(self, i, j):
|
||||||
|
for e in range(i, j):
|
||||||
|
self.ranges.add(e)
|
||||||
|
|
||||||
|
def __contains__(self, rang):
|
||||||
|
i, j = rang
|
||||||
|
for e in range(i, j):
|
||||||
|
if e in self.ranges:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class SpanCategorizer(TrainablePipe):
|
class SpanCategorizer(TrainablePipe):
|
||||||
"""Pipeline component to label spans of text.
|
"""Pipeline component to label spans of text.
|
||||||
|
|
||||||
|
@ -185,25 +308,43 @@ class SpanCategorizer(TrainablePipe):
|
||||||
suggester: Suggester,
|
suggester: Suggester,
|
||||||
name: str = "spancat",
|
name: str = "spancat",
|
||||||
*,
|
*,
|
||||||
|
add_negative_label: bool = False,
|
||||||
spans_key: str = "spans",
|
spans_key: str = "spans",
|
||||||
threshold: float = 0.5,
|
negative_weight: Optional[float] = 1.0,
|
||||||
|
allow_overlap: Optional[bool] = True,
|
||||||
max_positive: Optional[int] = None,
|
max_positive: Optional[int] = None,
|
||||||
|
threshold: Optional[float] = 0.5,
|
||||||
scorer: Optional[Callable] = spancat_score,
|
scorer: Optional[Callable] = spancat_score,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the span categorizer.
|
"""Initialize the multi-label or multi-class span categorizer.
|
||||||
|
|
||||||
vocab (Vocab): The shared vocabulary.
|
vocab (Vocab): The shared vocabulary.
|
||||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
For multi-class classification (single label per span) we recommend
|
||||||
|
using a Softmax classifier as a the final layer, while for multi-label
|
||||||
|
classification (multiple possible labels per span) we recommend Logistic.
|
||||||
|
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||||
|
Spans are returned as a ragged array with two integer columns, for the
|
||||||
|
start and end positions.
|
||||||
name (str): The component instance name, used to add entries to the
|
name (str): The component instance name, used to add entries to the
|
||||||
losses during training.
|
losses during training.
|
||||||
spans_key (str): Key of the Doc.spans dict to save the spans under.
|
spans_key (str): Key of the Doc.spans dict to save the spans under.
|
||||||
During initialization and training, the component will look for
|
During initialization and training, the component will look for
|
||||||
spans on the reference document under the same key. Defaults to
|
spans on the reference document under the same key. Defaults to
|
||||||
`"spans"`.
|
`"spans"`.
|
||||||
threshold (float): Minimum probability to consider a prediction
|
add_negative_label (bool): Learn to predict a special 'negative_label'
|
||||||
positive. Spans with a positive prediction will be saved on the Doc.
|
when a Span is not annotated.
|
||||||
Defaults to 0.5.
|
threshold (Optional[float]): Minimum probability to consider a prediction
|
||||||
|
positive. Defaults to 0.5. Spans with a positive prediction will be saved
|
||||||
|
on the Doc.
|
||||||
max_positive (Optional[int]): Maximum number of labels to consider
|
max_positive (Optional[int]): Maximum number of labels to consider
|
||||||
positive per span. Defaults to None, indicating no limit.
|
positive per span. Defaults to None, indicating no limit.
|
||||||
|
negative_weight (float): Multiplier for the loss terms.
|
||||||
|
Can be used to downweight the negative samples if there are too many
|
||||||
|
when add_negative_label is True. Otherwise its unused.
|
||||||
|
allow_overlap (bool): If True the data is assumed to contain overlapping spans.
|
||||||
|
Otherwise it produces non-overlapping spans greedily prioritizing
|
||||||
|
higher assigned label scores. Only used when max_positive is 1.
|
||||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||||
spans allowed.
|
spans allowed.
|
||||||
|
@ -215,12 +356,17 @@ class SpanCategorizer(TrainablePipe):
|
||||||
"spans_key": spans_key,
|
"spans_key": spans_key,
|
||||||
"threshold": threshold,
|
"threshold": threshold,
|
||||||
"max_positive": max_positive,
|
"max_positive": max_positive,
|
||||||
|
"negative_weight": negative_weight,
|
||||||
|
"allow_overlap": allow_overlap,
|
||||||
}
|
}
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.suggester = suggester
|
self.suggester = suggester
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.add_negative_label = add_negative_label
|
||||||
|
if not allow_overlap and max_positive is not None and max_positive > 1:
|
||||||
|
raise ValueError(Errors.E1051.format(max_positive=max_positive))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
|
@ -230,6 +376,21 @@ class SpanCategorizer(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
return str(self.cfg["spans_key"])
|
return str(self.cfg["spans_key"])
|
||||||
|
|
||||||
|
def _allow_extra_label(self) -> None:
|
||||||
|
"""Raise an error if the component can not add any more labels."""
|
||||||
|
nO = None
|
||||||
|
if self.model.has_dim("nO"):
|
||||||
|
nO = self.model.get_dim("nO")
|
||||||
|
elif self.model.has_ref("output_layer") and self.model.get_ref(
|
||||||
|
"output_layer"
|
||||||
|
).has_dim("nO"):
|
||||||
|
nO = self.model.get_ref("output_layer").get_dim("nO")
|
||||||
|
if nO is not None and nO == self._n_labels:
|
||||||
|
if not self.is_resizable:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E922.format(name=self.name, nO=self.model.get_dim("nO"))
|
||||||
|
)
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
"""Add a new label to the pipe.
|
"""Add a new label to the pipe.
|
||||||
|
|
||||||
|
@ -263,6 +424,27 @@ class SpanCategorizer(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
return list(self.labels)
|
return list(self.labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _label_map(self) -> Dict[str, int]:
|
||||||
|
"""RETURNS (Dict[str, int]): The label map."""
|
||||||
|
return {label: i for i, label in enumerate(self.labels)}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _n_labels(self) -> int:
|
||||||
|
"""RETURNS (int): Number of labels."""
|
||||||
|
if self.add_negative_label:
|
||||||
|
return len(self.labels) + 1
|
||||||
|
else:
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _negative_label_i(self) -> Union[int, None]:
|
||||||
|
"""RETURNS (Union[int, None]): Index of the negative label."""
|
||||||
|
if self.add_negative_label:
|
||||||
|
return len(self.label_data)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]):
|
def predict(self, docs: Iterable[Doc]):
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
|
||||||
|
@ -304,14 +486,24 @@ class SpanCategorizer(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/spancategorizer#set_annotations
|
DOCS: https://spacy.io/api/spancategorizer#set_annotations
|
||||||
"""
|
"""
|
||||||
labels = self.labels
|
|
||||||
indices, scores = indices_scores
|
indices, scores = indices_scores
|
||||||
offset = 0
|
offset = 0
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
indices_i = indices[i].dataXd
|
indices_i = indices[i].dataXd
|
||||||
doc.spans[self.key] = self._make_span_group(
|
allow_overlap = cast(bool, self.cfg["allow_overlap"])
|
||||||
doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
|
if self.cfg["max_positive"] == 1:
|
||||||
)
|
doc.spans[self.key] = self._make_span_group_singlelabel(
|
||||||
|
doc,
|
||||||
|
indices_i,
|
||||||
|
scores[offset : offset + indices.lengths[i]],
|
||||||
|
allow_overlap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
doc.spans[self.key] = self._make_span_group_multilabel(
|
||||||
|
doc,
|
||||||
|
indices_i,
|
||||||
|
scores[offset : offset + indices.lengths[i]],
|
||||||
|
)
|
||||||
offset += indices.lengths[i]
|
offset += indices.lengths[i]
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
|
@ -371,9 +563,11 @@ class SpanCategorizer(TrainablePipe):
|
||||||
spans = Ragged(
|
spans = Ragged(
|
||||||
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
|
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
|
||||||
)
|
)
|
||||||
label_map = {label: i for i, label in enumerate(self.labels)}
|
|
||||||
target = numpy.zeros(scores.shape, dtype=scores.dtype)
|
target = numpy.zeros(scores.shape, dtype=scores.dtype)
|
||||||
|
if self.add_negative_label:
|
||||||
|
negative_spans = numpy.ones((scores.shape[0]))
|
||||||
offset = 0
|
offset = 0
|
||||||
|
label_map = self._label_map
|
||||||
for i, eg in enumerate(examples):
|
for i, eg in enumerate(examples):
|
||||||
# Map (start, end) offset of spans to the row in the d_scores array,
|
# Map (start, end) offset of spans to the row in the d_scores array,
|
||||||
# so that we can adjust the gradient for predictions that were
|
# so that we can adjust the gradient for predictions that were
|
||||||
|
@ -390,10 +584,16 @@ class SpanCategorizer(TrainablePipe):
|
||||||
row = spans_index[key]
|
row = spans_index[key]
|
||||||
k = label_map[gold_span.label_]
|
k = label_map[gold_span.label_]
|
||||||
target[row, k] = 1.0
|
target[row, k] = 1.0
|
||||||
|
if self.add_negative_label:
|
||||||
|
# delete negative label target.
|
||||||
|
negative_spans[row] = 0.0
|
||||||
# The target is a flat array for all docs. Track the position
|
# The target is a flat array for all docs. Track the position
|
||||||
# we're at within the flat array.
|
# we're at within the flat array.
|
||||||
offset += spans.lengths[i]
|
offset += spans.lengths[i]
|
||||||
target = self.model.ops.asarray(target, dtype="f") # type: ignore
|
target = self.model.ops.asarray(target, dtype="f") # type: ignore
|
||||||
|
if self.add_negative_label:
|
||||||
|
negative_samples = numpy.nonzero(negative_spans)[0]
|
||||||
|
target[negative_samples, self._negative_label_i] = 1.0 # type: ignore
|
||||||
# The target will have the values 0 (for untrue predictions) or 1
|
# The target will have the values 0 (for untrue predictions) or 1
|
||||||
# (for true predictions).
|
# (for true predictions).
|
||||||
# The scores should be in the range [0, 1].
|
# The scores should be in the range [0, 1].
|
||||||
|
@ -402,6 +602,10 @@ class SpanCategorizer(TrainablePipe):
|
||||||
# If the prediction is 0.9 and it's false, the gradient will be
|
# If the prediction is 0.9 and it's false, the gradient will be
|
||||||
# 0.9 (0.9 - 0.0)
|
# 0.9 (0.9 - 0.0)
|
||||||
d_scores = scores - target
|
d_scores = scores - target
|
||||||
|
if self.add_negative_label:
|
||||||
|
neg_weight = cast(float, self.cfg["negative_weight"])
|
||||||
|
if neg_weight != 1.0:
|
||||||
|
d_scores[negative_samples] *= neg_weight
|
||||||
loss = float((d_scores**2).sum())
|
loss = float((d_scores**2).sum())
|
||||||
return loss, d_scores
|
return loss, d_scores
|
||||||
|
|
||||||
|
@ -438,7 +642,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
if subbatch:
|
if subbatch:
|
||||||
docs = [eg.x for eg in subbatch]
|
docs = [eg.x for eg in subbatch]
|
||||||
spans = build_ngram_suggester(sizes=[1])(docs)
|
spans = build_ngram_suggester(sizes=[1])(docs)
|
||||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels))
|
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels)
|
||||||
self.model.initialize(X=(docs, spans), Y=Y)
|
self.model.initialize(X=(docs, spans), Y=Y)
|
||||||
else:
|
else:
|
||||||
self.model.initialize()
|
self.model.initialize()
|
||||||
|
@ -452,31 +656,96 @@ class SpanCategorizer(TrainablePipe):
|
||||||
eg.reference.spans.get(self.key, []), allow_overlap=True
|
eg.reference.spans.get(self.key, []), allow_overlap=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_span_group(
|
def _make_span_group_multilabel(
|
||||||
self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
|
self,
|
||||||
|
doc: Doc,
|
||||||
|
indices: Ints2d,
|
||||||
|
scores: Floats2d,
|
||||||
) -> SpanGroup:
|
) -> SpanGroup:
|
||||||
|
"""Find the top-k labels for each span (k=max_positive)."""
|
||||||
spans = SpanGroup(doc, name=self.key)
|
spans = SpanGroup(doc, name=self.key)
|
||||||
max_positive = self.cfg["max_positive"]
|
if scores.size == 0:
|
||||||
|
return spans
|
||||||
|
scores = self.model.ops.to_numpy(scores)
|
||||||
|
indices = self.model.ops.to_numpy(indices)
|
||||||
threshold = self.cfg["threshold"]
|
threshold = self.cfg["threshold"]
|
||||||
|
max_positive = self.cfg["max_positive"]
|
||||||
|
|
||||||
keeps = scores >= threshold
|
keeps = scores >= threshold
|
||||||
ranked = (scores * -1).argsort() # type: ignore
|
|
||||||
if max_positive is not None:
|
if max_positive is not None:
|
||||||
assert isinstance(max_positive, int)
|
assert isinstance(max_positive, int)
|
||||||
|
if self.add_negative_label:
|
||||||
|
negative_scores = numpy.copy(scores[:, self._negative_label_i])
|
||||||
|
scores[:, self._negative_label_i] = -numpy.inf
|
||||||
|
ranked = (scores * -1).argsort() # type: ignore
|
||||||
|
scores[:, self._negative_label_i] = negative_scores
|
||||||
|
else:
|
||||||
|
ranked = (scores * -1).argsort() # type: ignore
|
||||||
span_filter = ranked[:, max_positive:]
|
span_filter = ranked[:, max_positive:]
|
||||||
for i, row in enumerate(span_filter):
|
for i, row in enumerate(span_filter):
|
||||||
keeps[i, row] = False
|
keeps[i, row] = False
|
||||||
spans.attrs["scores"] = scores[keeps].flatten()
|
|
||||||
|
|
||||||
indices = self.model.ops.to_numpy(indices)
|
|
||||||
keeps = self.model.ops.to_numpy(keeps)
|
|
||||||
|
|
||||||
|
attrs_scores = []
|
||||||
for i in range(indices.shape[0]):
|
for i in range(indices.shape[0]):
|
||||||
start = indices[i, 0]
|
start = indices[i, 0]
|
||||||
end = indices[i, 1]
|
end = indices[i, 1]
|
||||||
|
|
||||||
for j, keep in enumerate(keeps[i]):
|
for j, keep in enumerate(keeps[i]):
|
||||||
if keep:
|
if keep:
|
||||||
spans.append(Span(doc, start, end, label=labels[j]))
|
if j != self._negative_label_i:
|
||||||
|
spans.append(Span(doc, start, end, label=self.labels[j]))
|
||||||
|
attrs_scores.append(scores[i, j])
|
||||||
|
spans.attrs["scores"] = numpy.array(attrs_scores)
|
||||||
|
return spans
|
||||||
|
|
||||||
|
def _make_span_group_singlelabel(
|
||||||
|
self,
|
||||||
|
doc: Doc,
|
||||||
|
indices: Ints2d,
|
||||||
|
scores: Floats2d,
|
||||||
|
allow_overlap: bool = True,
|
||||||
|
) -> SpanGroup:
|
||||||
|
"""Find the argmax label for each span."""
|
||||||
|
# Handle cases when there are zero suggestions
|
||||||
|
if scores.size == 0:
|
||||||
|
return SpanGroup(doc, name=self.key)
|
||||||
|
scores = self.model.ops.to_numpy(scores)
|
||||||
|
indices = self.model.ops.to_numpy(indices)
|
||||||
|
predicted = scores.argmax(axis=1)
|
||||||
|
argmax_scores = numpy.take_along_axis(
|
||||||
|
scores, numpy.expand_dims(predicted, 1), axis=1
|
||||||
|
)
|
||||||
|
keeps = numpy.ones(predicted.shape, dtype=bool)
|
||||||
|
# Remove samples where the negative label is the argmax.
|
||||||
|
if self.add_negative_label:
|
||||||
|
keeps = numpy.logical_and(keeps, predicted != self._negative_label_i)
|
||||||
|
# Filter samples according to threshold.
|
||||||
|
threshold = self.cfg["threshold"]
|
||||||
|
if threshold is not None:
|
||||||
|
keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
|
||||||
|
# Sort spans according to argmax probability
|
||||||
|
if not allow_overlap:
|
||||||
|
# Get the probabilities
|
||||||
|
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
||||||
|
predicted = predicted[sort_idx]
|
||||||
|
indices = indices[sort_idx]
|
||||||
|
keeps = keeps[sort_idx]
|
||||||
|
seen = _Intervals()
|
||||||
|
spans = SpanGroup(doc, name=self.key)
|
||||||
|
attrs_scores = []
|
||||||
|
for i in range(indices.shape[0]):
|
||||||
|
if not keeps[i]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
label = predicted[i]
|
||||||
|
start = indices[i, 0]
|
||||||
|
end = indices[i, 1]
|
||||||
|
|
||||||
|
if not allow_overlap:
|
||||||
|
if (start, end) in seen:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
seen.add(start, end)
|
||||||
|
attrs_scores.append(argmax_scores[i])
|
||||||
|
spans.append(Span(doc, start, end, label=self.labels[label]))
|
||||||
|
|
||||||
return spans
|
return spans
|
||||||
|
|
|
@ -15,6 +15,8 @@ OPS = get_current_ops()
|
||||||
|
|
||||||
SPAN_KEY = "labeled_spans"
|
SPAN_KEY = "labeled_spans"
|
||||||
|
|
||||||
|
SPANCAT_COMPONENTS = ["spancat", "spancat_singlelabel"]
|
||||||
|
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
|
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
|
||||||
(
|
(
|
||||||
|
@ -41,38 +43,42 @@ def make_examples(nlp, data=TRAIN_DATA):
|
||||||
return train_examples
|
return train_examples
|
||||||
|
|
||||||
|
|
||||||
def test_no_label():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_no_label(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
|
|
||||||
|
|
||||||
def test_no_resize():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_no_resize(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
spancat.add_label("Thing")
|
spancat.add_label("Thing")
|
||||||
spancat.add_label("Phrase")
|
spancat.add_label("Phrase")
|
||||||
assert spancat.labels == ("Thing", "Phrase")
|
assert spancat.labels == ("Thing", "Phrase")
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert spancat.model.get_dim("nO") == 2
|
assert spancat.model.get_dim("nO") == spancat._n_labels
|
||||||
# this throws an error because the spancat can't be resized after initialization
|
# this throws an error because the spancat can't be resized after initialization
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
spancat.add_label("Stuff")
|
spancat.add_label("Stuff")
|
||||||
|
|
||||||
|
|
||||||
def test_implicit_labels():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_implicit_labels(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
assert len(spancat.labels) == 0
|
assert len(spancat.labels) == 0
|
||||||
train_examples = make_examples(nlp)
|
train_examples = make_examples(nlp)
|
||||||
nlp.initialize(get_examples=lambda: train_examples)
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert spancat.labels == ("PERSON", "LOC")
|
assert spancat.labels == ("PERSON", "LOC")
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_labels():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_explicit_labels(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
assert len(spancat.labels) == 0
|
assert len(spancat.labels) == 0
|
||||||
spancat.add_label("PERSON")
|
spancat.add_label("PERSON")
|
||||||
spancat.add_label("LOC")
|
spancat.add_label("LOC")
|
||||||
|
@ -102,13 +108,13 @@ def test_doc_gc():
|
||||||
# XXX This fails with length 0 sometimes
|
# XXX This fails with length 0 sometimes
|
||||||
assert len(spangroup) > 0
|
assert len(spangroup) > 0
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
span = spangroup[0]
|
spangroup[0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
|
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
|
||||||
)
|
)
|
||||||
def test_make_spangroup(max_positive, nr_results):
|
def test_make_spangroup_multilabel(max_positive, nr_results):
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
spancat = nlp.add_pipe(
|
spancat = nlp.add_pipe(
|
||||||
|
@ -120,10 +126,12 @@ def test_make_spangroup(max_positive, nr_results):
|
||||||
indices = ngram_suggester([doc])[0].dataXd
|
indices = ngram_suggester([doc])[0].dataXd
|
||||||
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||||
labels = ["Thing", "City", "Person", "GreatCity"]
|
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||||
|
for label in labels:
|
||||||
|
spancat.add_label(label)
|
||||||
scores = numpy.asarray(
|
scores = numpy.asarray(
|
||||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||||
)
|
)
|
||||||
spangroup = spancat._make_span_group(doc, indices, scores, labels)
|
spangroup = spancat._make_span_group_multilabel(doc, indices, scores)
|
||||||
assert len(spangroup) == nr_results
|
assert len(spangroup) == nr_results
|
||||||
|
|
||||||
# first span is always the second token "London"
|
# first span is always the second token "London"
|
||||||
|
@ -154,6 +162,118 @@ def test_make_spangroup(max_positive, nr_results):
|
||||||
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"threshold,allow_overlap,nr_results",
|
||||||
|
[(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
|
||||||
|
)
|
||||||
|
def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = Language()
|
||||||
|
spancat = nlp.add_pipe(
|
||||||
|
"spancat",
|
||||||
|
config={
|
||||||
|
"spans_key": SPAN_KEY,
|
||||||
|
"threshold": threshold,
|
||||||
|
"max_positive": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
doc = nlp.make_doc("Greater London")
|
||||||
|
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
||||||
|
indices = ngram_suggester([doc])[0].dataXd
|
||||||
|
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||||
|
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||||
|
for label in labels:
|
||||||
|
spancat.add_label(label)
|
||||||
|
scores = numpy.asarray(
|
||||||
|
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||||
|
)
|
||||||
|
spangroup = spancat._make_span_group_singlelabel(
|
||||||
|
doc, indices, scores, allow_overlap
|
||||||
|
)
|
||||||
|
assert len(spangroup) == nr_results
|
||||||
|
if threshold > 0.4:
|
||||||
|
if allow_overlap:
|
||||||
|
assert spangroup[0].text == "London"
|
||||||
|
assert spangroup[0].label_ == "City"
|
||||||
|
assert spangroup[1].text == "Greater London"
|
||||||
|
assert spangroup[1].label_ == "GreatCity"
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert spangroup[0].text == "Greater London"
|
||||||
|
assert spangroup[0].label_ == "GreatCity"
|
||||||
|
else:
|
||||||
|
if allow_overlap:
|
||||||
|
assert spangroup[0].text == "Greater"
|
||||||
|
assert spangroup[0].label_ == "City"
|
||||||
|
assert spangroup[1].text == "London"
|
||||||
|
assert spangroup[1].label_ == "City"
|
||||||
|
assert spangroup[2].text == "Greater London"
|
||||||
|
assert spangroup[2].label_ == "GreatCity"
|
||||||
|
else:
|
||||||
|
assert spangroup[0].text == "Greater London"
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_spangroup_negative_label():
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp_single = Language()
|
||||||
|
nlp_multi = Language()
|
||||||
|
spancat_single = nlp_single.add_pipe(
|
||||||
|
"spancat",
|
||||||
|
config={
|
||||||
|
"spans_key": SPAN_KEY,
|
||||||
|
"threshold": 0.1,
|
||||||
|
"max_positive": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
spancat_multi = nlp_multi.add_pipe(
|
||||||
|
"spancat",
|
||||||
|
config={
|
||||||
|
"spans_key": SPAN_KEY,
|
||||||
|
"threshold": 0.1,
|
||||||
|
"max_positive": 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
spancat_single.add_negative_label = True
|
||||||
|
spancat_multi.add_negative_label = True
|
||||||
|
doc = nlp_single.make_doc("Greater London")
|
||||||
|
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||||
|
for label in labels:
|
||||||
|
spancat_multi.add_label(label)
|
||||||
|
spancat_single.add_label(label)
|
||||||
|
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
||||||
|
indices = ngram_suggester([doc])[0].dataXd
|
||||||
|
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||||
|
scores = numpy.asarray(
|
||||||
|
[
|
||||||
|
[0.2, 0.4, 0.3, 0.1, 0.1],
|
||||||
|
[0.1, 0.6, 0.2, 0.4, 0.9],
|
||||||
|
[0.8, 0.7, 0.3, 0.9, 0.1],
|
||||||
|
],
|
||||||
|
dtype="f",
|
||||||
|
)
|
||||||
|
spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores)
|
||||||
|
spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores)
|
||||||
|
assert len(spangroup_single) == 2
|
||||||
|
assert spangroup_single[0].text == "Greater"
|
||||||
|
assert spangroup_single[0].label_ == "City"
|
||||||
|
assert spangroup_single[1].text == "Greater London"
|
||||||
|
assert spangroup_single[1].label_ == "GreatCity"
|
||||||
|
|
||||||
|
assert len(spangroup_multi) == 6
|
||||||
|
assert spangroup_multi[0].text == "Greater"
|
||||||
|
assert spangroup_multi[0].label_ == "City"
|
||||||
|
assert spangroup_multi[1].text == "Greater"
|
||||||
|
assert spangroup_multi[1].label_ == "Person"
|
||||||
|
assert spangroup_multi[2].text == "London"
|
||||||
|
assert spangroup_multi[2].label_ == "City"
|
||||||
|
assert spangroup_multi[3].text == "London"
|
||||||
|
assert spangroup_multi[3].label_ == "GreatCity"
|
||||||
|
assert spangroup_multi[4].text == "Greater London"
|
||||||
|
assert spangroup_multi[4].label_ == "Thing"
|
||||||
|
assert spangroup_multi[5].text == "Greater London"
|
||||||
|
assert spangroup_multi[5].label_ == "GreatCity"
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_suggester(en_tokenizer):
|
def test_ngram_suggester(en_tokenizer):
|
||||||
# test different n-gram lengths
|
# test different n-gram lengths
|
||||||
for size in [1, 2, 3]:
|
for size in [1, 2, 3]:
|
||||||
|
@ -371,9 +491,9 @@ def test_overfitting_IO_overlapping():
|
||||||
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
|
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
|
||||||
|
|
||||||
|
|
||||||
def test_zero_suggestions():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_zero_suggestions(name):
|
||||||
# Test with a suggester that can return 0 suggestions
|
# Test with a suggester that can return 0 suggestions
|
||||||
|
|
||||||
@registry.misc("test_mixed_zero_suggester")
|
@registry.misc("test_mixed_zero_suggester")
|
||||||
def make_mixed_zero_suggester():
|
def make_mixed_zero_suggester():
|
||||||
def mixed_zero_suggester(docs, *, ops=None):
|
def mixed_zero_suggester(docs, *, ops=None):
|
||||||
|
@ -400,7 +520,7 @@ def test_zero_suggestions():
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
spancat = nlp.add_pipe(
|
spancat = nlp.add_pipe(
|
||||||
"spancat",
|
name,
|
||||||
config={
|
config={
|
||||||
"suggester": {"@misc": "test_mixed_zero_suggester"},
|
"suggester": {"@misc": "test_mixed_zero_suggester"},
|
||||||
"spans_key": SPAN_KEY,
|
"spans_key": SPAN_KEY,
|
||||||
|
@ -408,7 +528,7 @@ def test_zero_suggestions():
|
||||||
)
|
)
|
||||||
train_examples = make_examples(nlp)
|
train_examples = make_examples(nlp)
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert spancat.model.get_dim("nO") == 2
|
assert spancat.model.get_dim("nO") == spancat._n_labels
|
||||||
assert set(spancat.labels) == {"LOC", "PERSON"}
|
assert set(spancat.labels) == {"LOC", "PERSON"}
|
||||||
|
|
||||||
nlp.update(train_examples, sgd=optimizer)
|
nlp.update(train_examples, sgd=optimizer)
|
||||||
|
@ -424,9 +544,10 @@ def test_zero_suggestions():
|
||||||
list(nlp.pipe(["", "one", "three three three"]))
|
list(nlp.pipe(["", "one", "three three three"]))
|
||||||
|
|
||||||
|
|
||||||
def test_set_candidates():
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
def test_set_candidates(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
train_examples = make_examples(nlp)
|
train_examples = make_examples(nlp)
|
||||||
nlp.initialize(get_examples=lambda: train_examples)
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
texts = [
|
texts = [
|
||||||
|
|
|
@ -552,7 +552,14 @@ def test_parse_cli_overrides():
|
||||||
|
|
||||||
@pytest.mark.parametrize("lang", ["en", "nl"])
|
@pytest.mark.parametrize("lang", ["en", "nl"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
|
"pipeline",
|
||||||
|
[
|
||||||
|
["tagger", "parser", "ner"],
|
||||||
|
[],
|
||||||
|
["ner", "textcat", "sentencizer"],
|
||||||
|
["morphologizer", "spancat", "entity_linker"],
|
||||||
|
["spancat_singlelabel", "textcat_multilabel"],
|
||||||
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
|
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
|
||||||
@pytest.mark.parametrize("pretraining", [True, False])
|
@pytest.mark.parametrize("pretraining", [True, False])
|
||||||
|
|
|
@ -13,6 +13,13 @@ A span categorizer consists of two parts: a [suggester function](#suggesters)
|
||||||
that proposes candidate spans, which may or may not overlap, and a labeler model
|
that proposes candidate spans, which may or may not overlap, and a labeler model
|
||||||
that predicts zero or more labels for each candidate.
|
that predicts zero or more labels for each candidate.
|
||||||
|
|
||||||
|
This component comes in two forms: `spancat` and `spancat_singlelabel` (added in
|
||||||
|
spaCy v3.5.1). When you need to perform multi-label classification on your
|
||||||
|
spans, use `spancat`. The `spancat` component uses a `Logistic` layer where the
|
||||||
|
output class probabilities are independent for each class. However, if you need
|
||||||
|
to predict at most one true class for a span, then use `spancat_singlelabel`. It
|
||||||
|
uses a `Softmax` layer and treats the task as a multi-class problem.
|
||||||
|
|
||||||
Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc.
|
Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc.
|
||||||
Individual span scores can be found in `spangroup.attrs["scores"]`.
|
Individual span scores can be found in `spangroup.attrs["scores"]`.
|
||||||
|
|
||||||
|
@ -38,7 +45,7 @@ how the component should be configured. You can override its settings via the
|
||||||
[model architectures](/api/architectures) documentation for details on the
|
[model architectures](/api/architectures) documentation for details on the
|
||||||
architectures and their arguments and hyperparameters.
|
architectures and their arguments and hyperparameters.
|
||||||
|
|
||||||
> #### Example
|
> #### Example (spancat)
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL
|
> from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL
|
||||||
|
@ -52,14 +59,33 @@ architectures and their arguments and hyperparameters.
|
||||||
> nlp.add_pipe("spancat", config=config)
|
> nlp.add_pipe("spancat", config=config)
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Setting | Description |
|
> #### Example (spancat_singlelabel)
|
||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
>
|
||||||
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
|
> ```python
|
||||||
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
> from spacy.pipeline.spancat import DEFAULT_SPANCAT_SINGLELABEL_MODEL
|
||||||
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
|
> config = {
|
||||||
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
|
> "threshold": 0.5,
|
||||||
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
|
> "spans_key": "labeled_spans",
|
||||||
| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |
|
> "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||||
|
> "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||||
|
> # Additional spancat_singlelabel parameters
|
||||||
|
> "negative_weight": 0.8,
|
||||||
|
> "allow_overlap": True,
|
||||||
|
> }
|
||||||
|
> nlp.add_pipe("spancat_singlelabel", config=config)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Setting | Description |
|
||||||
|
| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
|
||||||
|
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
||||||
|
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
|
||||||
|
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Meant to be used in combination with the multi-class `spancat` component with a `Logistic` scoring layer. Defaults to `0.5`. ~~float~~ |
|
||||||
|
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. Meant to be used together with the `spancat` component and defaults to 0 with `spancat_singlelabel`. ~~Optional[int]~~ |
|
||||||
|
| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |
|
||||||
|
| `add_negative_label` <Tag variant="new">3.5.1</Tag> | Whether to learn to predict a special negative label for each unannotated `Span` . This should be `True` when using a `Softmax` classifier layer and so its `True` by default for `spancat_singlelabel`. Spans with negative labels and their scores are not stored as annotations. ~~bool~~ |
|
||||||
|
| `negative_weight` <Tag variant="new">3.5.1</Tag> | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many. It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ |
|
||||||
|
| `allow_overlap` <Tag variant="new">3.5.1</Tag> | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ |
|
||||||
|
|
||||||
```python
|
```python
|
||||||
%%GITHUB_SPACY/spacy/pipeline/spancat.py
|
%%GITHUB_SPACY/spacy/pipeline/spancat.py
|
||||||
|
@ -71,6 +97,7 @@ architectures and their arguments and hyperparameters.
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> # Construction via add_pipe with default model
|
> # Construction via add_pipe with default model
|
||||||
|
> # Replace 'spancat' with 'spancat_singlelabel' for exclusive classes
|
||||||
> spancat = nlp.add_pipe("spancat")
|
> spancat = nlp.add_pipe("spancat")
|
||||||
>
|
>
|
||||||
> # Construction via add_pipe with custom model
|
> # Construction via add_pipe with custom model
|
||||||
|
@ -86,16 +113,19 @@ Create a new pipeline instance. In your application, you would normally use a
|
||||||
shortcut for this and instantiate the component using its string name and
|
shortcut for this and instantiate the component using its string name and
|
||||||
[`nlp.add_pipe`](/api/language#create_pipe).
|
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
| --------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||||
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
||||||
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
|
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
|
||||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
|
| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
|
||||||
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
|
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
|
||||||
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
|
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
|
||||||
|
| `allow_overlap` <Tag variant="new">3.5.1</Tag> | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ |
|
||||||
|
| `add_negative_label` <Tag variant="new">3.5.1</Tag> | Whether to learn to predict a special negative label for each unannotated `Span`. This should be `True` when using a `Softmax` classifier layer and so its `True` by default for `spancat_singlelabel` . Spans with negative labels and their scores are not stored as annotations. ~~bool~~ |
|
||||||
|
| `negative_weight` <Tag variant="new">3.5.1</Tag> | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many . It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ |
|
||||||
|
|
||||||
## SpanCategorizer.\_\_call\_\_ {id="call",tag="method"}
|
## SpanCategorizer.\_\_call\_\_ {id="call",tag="method"}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user