diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index 441189341..c5e8c6c43 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and
can help generate the best possible configuration, given a user's requirements. #}
{%- set use_transformer = hardware != "cpu" and transformer_data -%}
{%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
-{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "trainable_lemmatizer"] -%}
+{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
[paths]
train = null
dev = null
@@ -28,7 +28,7 @@ lang = "{{ lang }}"
tok2vec/transformer. #}
{%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
{%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
-{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
+{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
{%- else -%}
{%- set full_pipeline = components -%}
@@ -159,6 +159,36 @@ grad_factor = 1.0
sizes = [1,2,3]
{% endif -%}
+{% if "spancat_singlelabel" in components %}
+[components.spancat_singlelabel]
+factory = "spancat_singlelabel"
+negative_weight = 1.0
+allow_overlap = true
+scorer = {"@scorers":"spacy.spancat_scorer.v1"}
+spans_key = "sc"
+
+[components.spancat_singlelabel.model]
+@architectures = "spacy.SpanCategorizer.v1"
+
+[components.spancat_singlelabel.model.reducer]
+@layers = "spacy.mean_max_reducer.v1"
+hidden_size = 128
+
+[components.spancat_singlelabel.model.scorer]
+@layers = "Softmax.v2"
+
+[components.spancat_singlelabel.model.tok2vec]
+@architectures = "spacy-transformers.TransformerListener.v1"
+grad_factor = 1.0
+
+[components.spancat_singlelabel.model.tok2vec.pooling]
+@layers = "reduce_mean.v1"
+
+[components.spancat_singlelabel.suggester]
+@misc = "spacy.ngram_suggester.v1"
+sizes = [1,2,3]
+{% endif %}
+
{% if "trainable_lemmatizer" in components -%}
[components.trainable_lemmatizer]
factory = "trainable_lemmatizer"
@@ -389,6 +419,33 @@ width = ${components.tok2vec.model.encode.width}
sizes = [1,2,3]
{% endif %}
+{% if "spancat_singlelabel" in components %}
+[components.spancat_singlelabel]
+factory = "spancat_singlelabel"
+negative_weight = 1.0
+allow_overlap = true
+scorer = {"@scorers":"spacy.spancat_scorer.v1"}
+spans_key = "sc"
+
+[components.spancat_singlelabel.model]
+@architectures = "spacy.SpanCategorizer.v1"
+
+[components.spancat_singlelabel.model.reducer]
+@layers = "spacy.mean_max_reducer.v1"
+hidden_size = 128
+
+[components.spancat_singlelabel.model.scorer]
+@layers = "Softmax.v2"
+
+[components.spancat_singlelabel.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode.width}
+
+[components.spancat_singlelabel.suggester]
+@misc = "spacy.ngram_suggester.v1"
+sizes = [1,2,3]
+{% endif %}
+
{% if "trainable_lemmatizer" in components -%}
[components.trainable_lemmatizer]
factory = "trainable_lemmatizer"
diff --git a/spacy/errors.py b/spacy/errors.py
index 1047ed21a..c897c29ff 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -969,6 +969,7 @@ class Errors(metaclass=ErrorsWithCodes):
"with `displacy.serve(doc, port=port)`")
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_select_port=True` to pick an available port automatically.")
+ E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
# Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index a3388e81a..983e1fba9 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -1,4 +1,5 @@
-from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
+from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
+from dataclasses import dataclass
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d
@@ -43,7 +44,36 @@ maxout_pieces = 3
depth = 4
"""
+spancat_singlelabel_default_config = """
+[model]
+@architectures = "spacy.SpanCategorizer.v1"
+scorer = {"@layers": "Softmax.v2"}
+
+[model.reducer]
+@layers = spacy.mean_max_reducer.v1
+hidden_size = 128
+
+[model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+[model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 96
+rows = [5000, 1000, 2500, 1000]
+attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+include_static_vectors = false
+
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = ${model.tok2vec.embed.width}
+window_size = 1
+maxout_pieces = 3
+depth = 4
+"""
+
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
+DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(
+ spancat_singlelabel_default_config
+)["model"]
@runtime_checkable
@@ -119,10 +149,14 @@ def make_spancat(
threshold: float,
max_positive: Optional[int],
) -> "SpanCategorizer":
- """Create a SpanCategorizer component. The span categorizer consists of two
+ """Create a SpanCategorizer component and configure it for multi-label
+ classification to be able to assign multiple labels for each span.
+ The span categorizer consists of two
parts: a suggester function that proposes candidate spans, and a labeller
model that predicts one or more labels for each span.
+ name (str): The component instance name, used to add entries to the
+ losses during training.
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
Spans are returned as a ragged array with two integer columns, for the
start and end positions.
@@ -144,12 +178,80 @@ def make_spancat(
"""
return SpanCategorizer(
nlp.vocab,
- suggester=suggester,
model=model,
- spans_key=spans_key,
- threshold=threshold,
- max_positive=max_positive,
+ suggester=suggester,
name=name,
+ spans_key=spans_key,
+ negative_weight=None,
+ allow_overlap=True,
+ max_positive=max_positive,
+ threshold=threshold,
+ scorer=scorer,
+ add_negative_label=False,
+ )
+
+
+@Language.factory(
+ "spancat_singlelabel",
+ assigns=["doc.spans"],
+ default_config={
+ "spans_key": "sc",
+ "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
+ "negative_weight": 1.0,
+ "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
+ "scorer": {"@scorers": "spacy.spancat_scorer.v1"},
+ "allow_overlap": True,
+ },
+ default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
+)
+def make_spancat_singlelabel(
+ nlp: Language,
+ name: str,
+ suggester: Suggester,
+ model: Model[Tuple[List[Doc], Ragged], Floats2d],
+ spans_key: str,
+ negative_weight: float,
+ allow_overlap: bool,
+ scorer: Optional[Callable],
+) -> "SpanCategorizer":
+ """Create a SpanCategorizer component and configure it for multi-class
+ classification. With this configuration each span can get at most one
+ label. The span categorizer consists of two
+ parts: a suggester function that proposes candidate spans, and a labeller
+ model that predicts at most one label for each span.
+
+ name (str): The component instance name, used to add entries to the
+ losses during training.
+ suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
+ Spans are returned as a ragged array with two integer columns, for the
+ start and end positions.
+ model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
+ is given a list of documents and (start, end) indices representing
+ candidate span offsets. The model predicts a probability for each category
+ for each span.
+ spans_key (str): Key of the doc.spans dict to save the spans under. During
+ initialization and training, the component will look for spans on the
+ reference document under the same key.
+ scorer (Optional[Callable]): The scoring method. Defaults to
+ Scorer.score_spans for the Doc.spans[spans_key] with overlapping
+ spans allowed.
+ negative_weight (float): Multiplier for the loss terms.
+ Can be used to downweight the negative samples if there are too many.
+ allow_overlap (bool): If True the data is assumed to contain overlapping spans.
+ Otherwise it produces non-overlapping spans greedily prioritizing
+ higher assigned label scores.
+ """
+ return SpanCategorizer(
+ nlp.vocab,
+ model=model,
+ suggester=suggester,
+ name=name,
+ spans_key=spans_key,
+ negative_weight=negative_weight,
+ allow_overlap=allow_overlap,
+ max_positive=1,
+ add_negative_label=True,
+ threshold=None,
scorer=scorer,
)
@@ -172,6 +274,27 @@ def make_spancat_scorer():
return spancat_score
+@dataclass
+class _Intervals:
+ """
+ Helper class to avoid storing overlapping spans.
+ """
+
+ def __init__(self):
+ self.ranges = set()
+
+ def add(self, i, j):
+ for e in range(i, j):
+ self.ranges.add(e)
+
+ def __contains__(self, rang):
+ i, j = rang
+ for e in range(i, j):
+ if e in self.ranges:
+ return True
+ return False
+
+
class SpanCategorizer(TrainablePipe):
"""Pipeline component to label spans of text.
@@ -185,25 +308,43 @@ class SpanCategorizer(TrainablePipe):
suggester: Suggester,
name: str = "spancat",
*,
+ add_negative_label: bool = False,
spans_key: str = "spans",
- threshold: float = 0.5,
+ negative_weight: Optional[float] = 1.0,
+ allow_overlap: Optional[bool] = True,
max_positive: Optional[int] = None,
+ threshold: Optional[float] = 0.5,
scorer: Optional[Callable] = spancat_score,
) -> None:
- """Initialize the span categorizer.
+ """Initialize the multi-label or multi-class span categorizer.
+
vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
+ For multi-class classification (single label per span) we recommend
+ using a Softmax classifier as the final layer, while for multi-label
+ classification (multiple possible labels per span) we recommend Logistic.
+ suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
+ Spans are returned as a ragged array with two integer columns, for the
+ start and end positions.
name (str): The component instance name, used to add entries to the
losses during training.
spans_key (str): Key of the Doc.spans dict to save the spans under.
During initialization and training, the component will look for
spans on the reference document under the same key. Defaults to
`"spans"`.
- threshold (float): Minimum probability to consider a prediction
- positive. Spans with a positive prediction will be saved on the Doc.
- Defaults to 0.5.
+ add_negative_label (bool): Learn to predict a special 'negative_label'
+ when a Span is not annotated.
+ threshold (Optional[float]): Minimum probability to consider a prediction
+ positive. Defaults to 0.5. Spans with a positive prediction will be saved
+ on the Doc.
max_positive (Optional[int]): Maximum number of labels to consider
positive per span. Defaults to None, indicating no limit.
+ negative_weight (float): Multiplier for the loss terms.
+ Can be used to downweight the negative samples if there are too many
+ when add_negative_label is True. Otherwise it's unused.
+ allow_overlap (bool): If True the data is assumed to contain overlapping spans.
+ Otherwise it produces non-overlapping spans greedily prioritizing
+ higher assigned label scores. Only used when max_positive is 1.
scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
spans allowed.
@@ -215,12 +356,17 @@ class SpanCategorizer(TrainablePipe):
"spans_key": spans_key,
"threshold": threshold,
"max_positive": max_positive,
+ "negative_weight": negative_weight,
+ "allow_overlap": allow_overlap,
}
self.vocab = vocab
self.suggester = suggester
self.model = model
self.name = name
self.scorer = scorer
+ self.add_negative_label = add_negative_label
+ if not allow_overlap and max_positive is not None and max_positive > 1:
+ raise ValueError(Errors.E1051.format(max_positive=max_positive))
@property
def key(self) -> str:
@@ -230,6 +376,21 @@ class SpanCategorizer(TrainablePipe):
"""
return str(self.cfg["spans_key"])
+ def _allow_extra_label(self) -> None:
+ """Raise an error if the component can not add any more labels."""
+ nO = None
+ if self.model.has_dim("nO"):
+ nO = self.model.get_dim("nO")
+ elif self.model.has_ref("output_layer") and self.model.get_ref(
+ "output_layer"
+ ).has_dim("nO"):
+ nO = self.model.get_ref("output_layer").get_dim("nO")
+ if nO is not None and nO == self._n_labels:
+ if not self.is_resizable:
+ raise ValueError(
+ Errors.E922.format(name=self.name, nO=self.model.get_dim("nO"))
+ )
+
def add_label(self, label: str) -> int:
"""Add a new label to the pipe.
@@ -263,6 +424,27 @@ class SpanCategorizer(TrainablePipe):
"""
return list(self.labels)
+ @property
+ def _label_map(self) -> Dict[str, int]:
+ """RETURNS (Dict[str, int]): The label map."""
+ return {label: i for i, label in enumerate(self.labels)}
+
+ @property
+ def _n_labels(self) -> int:
+ """RETURNS (int): Number of labels."""
+ if self.add_negative_label:
+ return len(self.labels) + 1
+ else:
+ return len(self.labels)
+
+ @property
+ def _negative_label_i(self) -> Union[int, None]:
+ """RETURNS (Union[int, None]): Index of the negative label."""
+ if self.add_negative_label:
+ return len(self.label_data)
+ else:
+ return None
+
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@@ -304,14 +486,24 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#set_annotations
"""
- labels = self.labels
indices, scores = indices_scores
offset = 0
for i, doc in enumerate(docs):
indices_i = indices[i].dataXd
- doc.spans[self.key] = self._make_span_group(
- doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
- )
+ allow_overlap = cast(bool, self.cfg["allow_overlap"])
+ if self.cfg["max_positive"] == 1:
+ doc.spans[self.key] = self._make_span_group_singlelabel(
+ doc,
+ indices_i,
+ scores[offset : offset + indices.lengths[i]],
+ allow_overlap,
+ )
+ else:
+ doc.spans[self.key] = self._make_span_group_multilabel(
+ doc,
+ indices_i,
+ scores[offset : offset + indices.lengths[i]],
+ )
offset += indices.lengths[i]
def update(
@@ -371,9 +563,11 @@ class SpanCategorizer(TrainablePipe):
spans = Ragged(
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
)
- label_map = {label: i for i, label in enumerate(self.labels)}
target = numpy.zeros(scores.shape, dtype=scores.dtype)
+ if self.add_negative_label:
+ negative_spans = numpy.ones((scores.shape[0]))
offset = 0
+ label_map = self._label_map
for i, eg in enumerate(examples):
# Map (start, end) offset of spans to the row in the d_scores array,
# so that we can adjust the gradient for predictions that were
@@ -390,10 +584,16 @@ class SpanCategorizer(TrainablePipe):
row = spans_index[key]
k = label_map[gold_span.label_]
target[row, k] = 1.0
+ if self.add_negative_label:
+ # delete negative label target.
+ negative_spans[row] = 0.0
# The target is a flat array for all docs. Track the position
# we're at within the flat array.
offset += spans.lengths[i]
target = self.model.ops.asarray(target, dtype="f") # type: ignore
+ if self.add_negative_label:
+ negative_samples = numpy.nonzero(negative_spans)[0]
+ target[negative_samples, self._negative_label_i] = 1.0 # type: ignore
# The target will have the values 0 (for untrue predictions) or 1
# (for true predictions).
# The scores should be in the range [0, 1].
@@ -402,6 +602,10 @@ class SpanCategorizer(TrainablePipe):
# If the prediction is 0.9 and it's false, the gradient will be
# 0.9 (0.9 - 0.0)
d_scores = scores - target
+ if self.add_negative_label:
+ neg_weight = cast(float, self.cfg["negative_weight"])
+ if neg_weight != 1.0:
+ d_scores[negative_samples] *= neg_weight
loss = float((d_scores**2).sum())
return loss, d_scores
@@ -438,7 +642,7 @@ class SpanCategorizer(TrainablePipe):
if subbatch:
docs = [eg.x for eg in subbatch]
spans = build_ngram_suggester(sizes=[1])(docs)
- Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels))
+ Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels)
self.model.initialize(X=(docs, spans), Y=Y)
else:
self.model.initialize()
@@ -452,31 +656,96 @@ class SpanCategorizer(TrainablePipe):
eg.reference.spans.get(self.key, []), allow_overlap=True
)
- def _make_span_group(
- self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
+ def _make_span_group_multilabel(
+ self,
+ doc: Doc,
+ indices: Ints2d,
+ scores: Floats2d,
) -> SpanGroup:
+ """Find the top-k labels for each span (k=max_positive)."""
spans = SpanGroup(doc, name=self.key)
- max_positive = self.cfg["max_positive"]
+ if scores.size == 0:
+ return spans
+ scores = self.model.ops.to_numpy(scores)
+ indices = self.model.ops.to_numpy(indices)
threshold = self.cfg["threshold"]
+ max_positive = self.cfg["max_positive"]
keeps = scores >= threshold
- ranked = (scores * -1).argsort() # type: ignore
if max_positive is not None:
assert isinstance(max_positive, int)
+ if self.add_negative_label:
+ negative_scores = numpy.copy(scores[:, self._negative_label_i])
+ scores[:, self._negative_label_i] = -numpy.inf
+ ranked = (scores * -1).argsort() # type: ignore
+ scores[:, self._negative_label_i] = negative_scores
+ else:
+ ranked = (scores * -1).argsort() # type: ignore
span_filter = ranked[:, max_positive:]
for i, row in enumerate(span_filter):
keeps[i, row] = False
- spans.attrs["scores"] = scores[keeps].flatten()
-
- indices = self.model.ops.to_numpy(indices)
- keeps = self.model.ops.to_numpy(keeps)
+ attrs_scores = []
for i in range(indices.shape[0]):
start = indices[i, 0]
end = indices[i, 1]
-
for j, keep in enumerate(keeps[i]):
if keep:
- spans.append(Span(doc, start, end, label=labels[j]))
+ if j != self._negative_label_i:
+ spans.append(Span(doc, start, end, label=self.labels[j]))
+ attrs_scores.append(scores[i, j])
+ spans.attrs["scores"] = numpy.array(attrs_scores)
+ return spans
+
+ def _make_span_group_singlelabel(
+ self,
+ doc: Doc,
+ indices: Ints2d,
+ scores: Floats2d,
+ allow_overlap: bool = True,
+ ) -> SpanGroup:
+ """Find the argmax label for each span."""
+ # Handle cases when there are zero suggestions
+ if scores.size == 0:
+ return SpanGroup(doc, name=self.key)
+ scores = self.model.ops.to_numpy(scores)
+ indices = self.model.ops.to_numpy(indices)
+ predicted = scores.argmax(axis=1)
+ argmax_scores = numpy.take_along_axis(
+ scores, numpy.expand_dims(predicted, 1), axis=1
+ )
+ keeps = numpy.ones(predicted.shape, dtype=bool)
+ # Remove samples where the negative label is the argmax.
+ if self.add_negative_label:
+ keeps = numpy.logical_and(keeps, predicted != self._negative_label_i)
+ # Filter samples according to threshold.
+ threshold = self.cfg["threshold"]
+ if threshold is not None:
+ keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
+ # Sort spans according to argmax probability
+ if not allow_overlap:
+ # Get the probabilities
+ sort_idx = (argmax_scores.squeeze() * -1).argsort()
+ predicted = predicted[sort_idx]
+ indices = indices[sort_idx]
+ keeps = keeps[sort_idx]
+ seen = _Intervals()
+ spans = SpanGroup(doc, name=self.key)
+ attrs_scores = []
+ for i in range(indices.shape[0]):
+ if not keeps[i]:
+ continue
+
+ label = predicted[i]
+ start = indices[i, 0]
+ end = indices[i, 1]
+
+ if not allow_overlap:
+ if (start, end) in seen:
+ continue
+ else:
+ seen.add(start, end)
+ attrs_scores.append(argmax_scores[i])
+ spans.append(Span(doc, start, end, label=self.labels[label]))
return spans
diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py
index e9db983d3..cf6304042 100644
--- a/spacy/tests/pipeline/test_spancat.py
+++ b/spacy/tests/pipeline/test_spancat.py
@@ -15,6 +15,8 @@ OPS = get_current_ops()
SPAN_KEY = "labeled_spans"
+SPANCAT_COMPONENTS = ["spancat", "spancat_singlelabel"]
+
TRAIN_DATA = [
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
(
@@ -41,38 +43,42 @@ def make_examples(nlp, data=TRAIN_DATA):
return train_examples
-def test_no_label():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_no_label(name):
nlp = Language()
- nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+ nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
with pytest.raises(ValueError):
nlp.initialize()
-def test_no_resize():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_no_resize(name):
nlp = Language()
- spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+ spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
spancat.add_label("Thing")
spancat.add_label("Phrase")
assert spancat.labels == ("Thing", "Phrase")
nlp.initialize()
- assert spancat.model.get_dim("nO") == 2
+ assert spancat.model.get_dim("nO") == spancat._n_labels
# this throws an error because the spancat can't be resized after initialization
with pytest.raises(ValueError):
spancat.add_label("Stuff")
-def test_implicit_labels():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_implicit_labels(name):
nlp = Language()
- spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+ spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
assert spancat.labels == ("PERSON", "LOC")
-def test_explicit_labels():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_explicit_labels(name):
nlp = Language()
- spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+ spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
spancat.add_label("PERSON")
spancat.add_label("LOC")
@@ -102,13 +108,13 @@ def test_doc_gc():
# XXX This fails with length 0 sometimes
assert len(spangroup) > 0
with pytest.raises(RuntimeError):
- span = spangroup[0]
+ spangroup[0]
@pytest.mark.parametrize(
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
)
-def test_make_spangroup(max_positive, nr_results):
+def test_make_spangroup_multilabel(max_positive, nr_results):
fix_random_seed(0)
nlp = Language()
spancat = nlp.add_pipe(
@@ -120,10 +126,12 @@ def test_make_spangroup(max_positive, nr_results):
indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
labels = ["Thing", "City", "Person", "GreatCity"]
+ for label in labels:
+ spancat.add_label(label)
scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
)
- spangroup = spancat._make_span_group(doc, indices, scores, labels)
+ spangroup = spancat._make_span_group_multilabel(doc, indices, scores)
assert len(spangroup) == nr_results
# first span is always the second token "London"
@@ -154,6 +162,118 @@ def test_make_spangroup(max_positive, nr_results):
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
+# Exercises _make_span_group_singlelabel on a "spancat" pipe configured with
+# max_positive=1, for two thresholds with and without overlapping spans.
+# Row-wise argmax of the scores below: "Greater" -> City (0.4),
+# "London" -> City (0.6), "Greater London" -> GreatCity (0.9).
+@pytest.mark.parametrize(
+ "threshold,allow_overlap,nr_results",
+ [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
+)
+def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
+ fix_random_seed(0)
+ nlp = Language()
+ spancat = nlp.add_pipe(
+ "spancat",
+ config={
+ "spans_key": SPAN_KEY,
+ "threshold": threshold,
+ "max_positive": 1,
+ },
+ )
+ doc = nlp.make_doc("Greater London")
+ ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
+ indices = ngram_suggester([doc])[0].dataXd
+ assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
+ labels = ["Thing", "City", "Person", "GreatCity"]
+ for label in labels:
+ spancat.add_label(label)
+ # One row per candidate span, one column per label (same order as `labels`).
+ scores = numpy.asarray(
+ [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
+ )
+ spangroup = spancat._make_span_group_singlelabel(
+ doc, indices, scores, allow_overlap
+ )
+ assert len(spangroup) == nr_results
+ # threshold=0.5: the 0.4 argmax of "Greater" falls below the threshold.
+ if threshold > 0.4:
+ if allow_overlap:
+ assert spangroup[0].text == "London"
+ assert spangroup[0].label_ == "City"
+ assert spangroup[1].text == "Greater London"
+ assert spangroup[1].label_ == "GreatCity"
+
+ else:
+ # Only the best-scoring span is expected; "London" shares a token with it.
+ assert spangroup[0].text == "Greater London"
+ assert spangroup[0].label_ == "GreatCity"
+ # threshold=0.05: every argmax score passes the threshold.
+ else:
+ if allow_overlap:
+ assert spangroup[0].text == "Greater"
+ assert spangroup[0].label_ == "City"
+ assert spangroup[1].text == "London"
+ assert spangroup[1].label_ == "City"
+ assert spangroup[2].text == "Greater London"
+ assert spangroup[2].label_ == "GreatCity"
+ else:
+ # Both unigrams share a token with the best-scoring bigram.
+ assert spangroup[0].text == "Greater London"
+
+
+# Checks that the negative label is never emitted as a span label, on both the
+# single-label (max_positive=1) and the multi-label (max_positive=2) path.
+def test_make_spangroup_negative_label():
+ fix_random_seed(0)
+ nlp_single = Language()
+ nlp_multi = Language()
+ spancat_single = nlp_single.add_pipe(
+ "spancat",
+ config={
+ "spans_key": SPAN_KEY,
+ "threshold": 0.1,
+ "max_positive": 1,
+ },
+ )
+ spancat_multi = nlp_multi.add_pipe(
+ "spancat",
+ config={
+ "spans_key": SPAN_KEY,
+ "threshold": 0.1,
+ "max_positive": 2,
+ },
+ )
+ # The flag is set directly on the components after they have been created.
+ spancat_single.add_negative_label = True
+ spancat_multi.add_negative_label = True
+ doc = nlp_single.make_doc("Greater London")
+ labels = ["Thing", "City", "Person", "GreatCity"]
+ for label in labels:
+ spancat_multi.add_label(label)
+ spancat_single.add_label(label)
+ ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
+ indices = ngram_suggester([doc])[0].dataXd
+ assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
+ # Four label columns followed by a fifth column (index 4 == number of
+ # labels), which is the one treated as the negative label.
+ scores = numpy.asarray(
+ [
+ [0.2, 0.4, 0.3, 0.1, 0.1],
+ [0.1, 0.6, 0.2, 0.4, 0.9],
+ [0.8, 0.7, 0.3, 0.9, 0.1],
+ ],
+ dtype="f",
+ )
+ spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores)
+ spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores)
+ # Single-label: the argmax for "London" is the fifth column (0.9), so that
+ # span is dropped and two spans remain.
+ assert len(spangroup_single) == 2
+ assert spangroup_single[0].text == "Greater"
+ assert spangroup_single[0].label_ == "City"
+ assert spangroup_single[1].text == "Greater London"
+ assert spangroup_single[1].label_ == "GreatCity"
+
+ # Multi-label: each span is expected to keep its two best labels, with the
+ # fifth column never taking one of those two slots.
+ assert len(spangroup_multi) == 6
+ assert spangroup_multi[0].text == "Greater"
+ assert spangroup_multi[0].label_ == "City"
+ assert spangroup_multi[1].text == "Greater"
+ assert spangroup_multi[1].label_ == "Person"
+ assert spangroup_multi[2].text == "London"
+ assert spangroup_multi[2].label_ == "City"
+ assert spangroup_multi[3].text == "London"
+ assert spangroup_multi[3].label_ == "GreatCity"
+ assert spangroup_multi[4].text == "Greater London"
+ assert spangroup_multi[4].label_ == "Thing"
+ assert spangroup_multi[5].text == "Greater London"
+ assert spangroup_multi[5].label_ == "GreatCity"
+
+
def test_ngram_suggester(en_tokenizer):
# test different n-gram lengths
for size in [1, 2, 3]:
@@ -371,9 +491,9 @@ def test_overfitting_IO_overlapping():
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
-def test_zero_suggestions():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_zero_suggestions(name):
# Test with a suggester that can return 0 suggestions
-
@registry.misc("test_mixed_zero_suggester")
def make_mixed_zero_suggester():
def mixed_zero_suggester(docs, *, ops=None):
@@ -400,7 +520,7 @@ def test_zero_suggestions():
fix_random_seed(0)
nlp = English()
spancat = nlp.add_pipe(
- "spancat",
+ name,
config={
"suggester": {"@misc": "test_mixed_zero_suggester"},
"spans_key": SPAN_KEY,
@@ -408,7 +528,7 @@ def test_zero_suggestions():
)
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
- assert spancat.model.get_dim("nO") == 2
+ assert spancat.model.get_dim("nO") == spancat._n_labels
assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer)
@@ -424,9 +544,10 @@ def test_zero_suggestions():
list(nlp.pipe(["", "one", "three three three"]))
-def test_set_candidates():
+@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
+def test_set_candidates(name):
nlp = Language()
- spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
+ spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
texts = [
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index f5bcdfd23..1fdf059b3 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -552,7 +552,14 @@ def test_parse_cli_overrides():
@pytest.mark.parametrize("lang", ["en", "nl"])
@pytest.mark.parametrize(
- "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
+ "pipeline",
+ [
+ ["tagger", "parser", "ner"],
+ [],
+ ["ner", "textcat", "sentencizer"],
+ ["morphologizer", "spancat", "entity_linker"],
+ ["spancat_singlelabel", "textcat_multilabel"],
+ ],
)
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
@pytest.mark.parametrize("pretraining", [True, False])
diff --git a/website/docs/api/spancategorizer.mdx b/website/docs/api/spancategorizer.mdx
index f39c0aff9..c7de2324b 100644
--- a/website/docs/api/spancategorizer.mdx
+++ b/website/docs/api/spancategorizer.mdx
@@ -13,6 +13,13 @@ A span categorizer consists of two parts: a [suggester function](#suggesters)
that proposes candidate spans, which may or may not overlap, and a labeler model
that predicts zero or more labels for each candidate.
+This component comes in two forms: `spancat` and `spancat_singlelabel` (added in
+spaCy v3.5.1). When you need to perform multi-label classification on your
+spans, use `spancat`. The `spancat` component uses a `Logistic` layer where the
+output class probabilities are independent for each class. However, if you need
+to predict at most one true class for a span, then use `spancat_singlelabel`. It
+uses a `Softmax` layer and treats the task as a multi-class problem.
+
Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the doc.
Individual span scores can be found in `spangroup.attrs["scores"]`.
@@ -38,7 +45,7 @@ how the component should be configured. You can override its settings via the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.
-> #### Example
+> #### Example (spancat)
>
> ```python
> from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL
@@ -52,14 +59,33 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("spancat", config=config)
> ```
-| Setting | Description |
-| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
-| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
-| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
-| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
-| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
-| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |
+> #### Example (spancat_singlelabel)
+>
+> ```python
+> from spacy.pipeline.spancat import DEFAULT_SPANCAT_SINGLELABEL_MODEL
+> config = {
+> "threshold": 0.5,
+> "spans_key": "labeled_spans",
+> "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
+> "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
+> # Additional spancat_singlelabel parameters
+> "negative_weight": 0.8,
+> "allow_overlap": True,
+> }
+> nlp.add_pipe("spancat_singlelabel", config=config)
+> ```
+
+| Setting | Description |
+| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
+| `model` | A model instance that is given a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
+| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
+| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Meant to be used in combination with the multi-label `spancat` component with a `Logistic` scoring layer. Defaults to `0.5`. ~~float~~ |
+| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. Meant to be used together with the `spancat` component and defaults to 0 with `spancat_singlelabel`. ~~Optional[int]~~ |
+| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |
+| `add_negative_label` 3.5.1 | Whether to learn to predict a special negative label for each unannotated `Span`. This should be `True` when using a `Softmax` classifier layer and so it is `True` by default for `spancat_singlelabel`. Spans with negative labels and their scores are not stored as annotations. ~~bool~~ |
+| `negative_weight` 3.5.1 | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many. It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ |
+| `allow_overlap` 3.5.1 | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/spancat.py
@@ -71,6 +97,7 @@ architectures and their arguments and hyperparameters.
>
> ```python
> # Construction via add_pipe with default model
+> # Replace 'spancat' with 'spancat_singlelabel' for exclusive classes
> spancat = nlp.add_pipe("spancat")
>
> # Construction via add_pipe with custom model
@@ -86,16 +113,19 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).
-| Name | Description |
-| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `vocab` | The shared vocabulary. ~~Vocab~~ |
-| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
-| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
-| _keyword-only_ | |
-| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
-| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
-| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
+| Name | Description |
+| --------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `vocab` | The shared vocabulary. ~~Vocab~~ |
+| `model` | A model instance that is given a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
+| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
+| _keyword-only_ | |
+| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
+| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
+| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
+| `allow_overlap` 3.5.1 | If `True`, the data is assumed to contain overlapping spans. It is only available when `max_positive` is exactly 1. Defaults to `True`. ~~bool~~ |
+| `add_negative_label` 3.5.1 | Whether to learn to predict a special negative label for each unannotated `Span`. This should be `True` when using a `Softmax` classifier layer and so it is `True` by default for `spancat_singlelabel`. Spans with negative labels and their scores are not stored as annotations. ~~bool~~ |
+| `negative_weight` 3.5.1 | Multiplier for the loss terms. It can be used to downweight the negative samples if there are too many. It is only used when `add_negative_label` is `True`. Defaults to `1.0`. ~~float~~ |
## SpanCategorizer.\_\_call\_\_ {id="call",tag="method"}