mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add SpanCategorizer component (#6747)
* Draft spancat model * Add spancat model * Add test for extract_spans * Add extract_spans layer * Upd extract_spans * Add spancat model * Add test for spancat model * Upd spancat model * Update spancat component * Upd spancat * Update spancat model * Add quick spancat test * Import SpanCategorizer * Fix SpanCategorizer component * Import SpanGroup * Fix span extraction * Fix import * Fix import * Upd model * Update spancat models * Add scoring, update defaults * Update and add docs * Fix type * Update spacy/ml/extract_spans.py * Auto-format and fix import * Fix comment * Fix type * Fix type * Update website/docs/api/spancategorizer.md * Fix comment Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Better defense Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Fix labels list Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/ml/extract_spans.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/pipeline/spancat.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Set annotations during update * Set annotations in spancat * fix imports in test * Update spacy/pipeline/spancat.py * replace MaxoutLogistic with LinearLogistic * fix config * various small fixes * remove set_annotations parameter in update * use our beloved tupley format with recent support for doc.spans * bugfix to allow renaming the default span_key (scores weren't showing up) * use different key in docs example * change defaults to better-working parameters from project (WIP) * register spacy.extract_spans.v1 for legacy purposes * Upd dev version so can build wheel * layers instead of architectures for smaller building blocks * Update website/docs/api/spancategorizer.md Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update website/docs/api/spancategorizer.md Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Include additional scores from overrides in combined score weights * Parameterize spans key in scoring Parameterize the `SpanCategorizer` `spans_key` for scoring purposes so that it's possible to evaluate multiple `spancat` components in the same pipeline. * Use the (intentionally very short) default spans key `sc` in the `SpanCategorizer` * Adjust the default score weights to include the default key * Adjust the scorer to use `spans_{spans_key}` as the prefix for the returned score * Revert addition of `attr_name` argument to `score_spans` and adjust the key in the `getter` instead. Note that for `spancat` components with a custom `span_key`, the score weights currently need to be modified manually in `[training.score_weights]` for them to be available during training. To suppress the default score weights `spans_sc_p/r/f` during training, set them to `null` in `[training.score_weights]`. * Update website/docs/api/scorer.md * Fix scorer for spans key containing underscore * Increment version * Add Spans to Evaluate CLI (#8439) * Add Spans to Evaluate CLI * Change to spans_key * Add spans per_type output Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Fix spancat GPU issues (#8455) * Fix GPU issues * Require thinc >=8.0.6 * Switch to glorot_uniform_init * Fix and test ngram suggester * Include final ngram in doc for all sizes * Fix ngrams for docs of the same length as ngram size * Handle batches of docs that result in no ngrams * Add tests Co-authored-by: Ines Montani <ines@ines.io> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Nirant <NirantK@users.noreply.github.com>
This commit is contained in:
parent
172dfec4f2
commit
f9946154d9
|
@ -5,7 +5,7 @@ requires = [
|
|||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.0.5,<8.1.0",
|
||||
"thinc>=8.0.6,<8.1.0",
|
||||
"blis>=0.4.0,<0.8.0",
|
||||
"pathy",
|
||||
"numpy>=1.15.0",
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
spacy-legacy>=3.0.6,<3.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.5,<8.1.0
|
||||
thinc>=8.0.6,<8.1.0
|
||||
blis>=0.4.0,<0.8.0
|
||||
ml_datasets>=0.2.0,<0.3.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
|
|
|
@ -37,14 +37,14 @@ setup_requires =
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
thinc>=8.0.5,<8.1.0
|
||||
thinc>=8.0.6,<8.1.0
|
||||
install_requires =
|
||||
# Our libraries
|
||||
spacy-legacy>=3.0.6,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.5,<8.1.0
|
||||
thinc>=8.0.6,<8.1.0
|
||||
blis>=0.4.0,<0.8.0
|
||||
wasabi>=0.8.1,<1.1.0
|
||||
srsly>=2.4.1,<3.0.0
|
||||
|
|
|
@ -60,6 +60,7 @@ def evaluate(
|
|||
displacy_path: Optional[Path] = None,
|
||||
displacy_limit: int = 25,
|
||||
silent: bool = True,
|
||||
spans_key="sc",
|
||||
) -> Scorer:
|
||||
msg = Printer(no_print=silent, pretty=not silent)
|
||||
fix_random_seed()
|
||||
|
@ -90,6 +91,9 @@ def evaluate(
|
|||
"SENT P": "sents_p",
|
||||
"SENT R": "sents_r",
|
||||
"SENT F": "sents_f",
|
||||
"SPAN P": f"spans_{spans_key}_p",
|
||||
"SPAN R": f"spans_{spans_key}_r",
|
||||
"SPAN F": f"spans_{spans_key}_f",
|
||||
"SPEED": "speed",
|
||||
}
|
||||
results = {}
|
||||
|
@ -121,6 +125,10 @@ def evaluate(
|
|||
if scores["ents_per_type"]:
|
||||
print_prf_per_type(msg, scores["ents_per_type"], "NER", "type")
|
||||
data["ents_per_type"] = scores["ents_per_type"]
|
||||
if f"spans_{spans_key}_per_type" in scores:
|
||||
if scores[f"spans_{spans_key}_per_type"]:
|
||||
print_prf_per_type(msg, scores[f"spans_{spans_key}_per_type"], "SPANS", "type")
|
||||
data[f"spans_{spans_key}_per_type"] = scores[f"spans_{spans_key}_per_type"]
|
||||
if "cats_f_per_type" in scores:
|
||||
if scores["cats_f_per_type"]:
|
||||
print_prf_per_type(msg, scores["cats_f_per_type"], "Textcat F", "label")
|
||||
|
|
60
spacy/ml/extract_spans.py
Normal file
60
spacy/ml/extract_spans.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from typing import Tuple, Callable
|
||||
from thinc.api import Model, to_numpy
|
||||
from thinc.types import Ragged, Ints1d
|
||||
|
||||
from ..util import registry
|
||||
|
||||
|
||||
@registry.layers("spacy.extract_spans.v1")
|
||||
def extract_spans() -> Model[Tuple[Ragged, Ragged], Ragged]:
|
||||
"""Extract spans from a sequence of source arrays, as specified by an array
|
||||
of (start, end) indices. The output is a ragged array of the
|
||||
extracted spans.
|
||||
"""
|
||||
return Model(
|
||||
"extract_spans", forward, layers=[], refs={}, attrs={}, dims={}, init=init
|
||||
)
|
||||
|
||||
|
||||
def init(model, X=None, Y=None):
|
||||
pass
|
||||
|
||||
|
||||
def forward(
|
||||
model: Model, source_spans: Tuple[Ragged, Ragged], is_train: bool
|
||||
) -> Tuple[Ragged, Callable]:
|
||||
"""Get subsequences from source vectors."""
|
||||
ops = model.ops
|
||||
X, spans = source_spans
|
||||
assert spans.dataXd.ndim == 2
|
||||
indices = _get_span_indices(ops, spans, X.lengths)
|
||||
Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0])
|
||||
x_shape = X.dataXd.shape
|
||||
x_lengths = X.lengths
|
||||
|
||||
def backprop_windows(dY: Ragged) -> Tuple[Ragged, Ragged]:
|
||||
dX = Ragged(ops.alloc2f(*x_shape), x_lengths)
|
||||
ops.scatter_add(dX.dataXd, indices, dY.dataXd)
|
||||
return (dX, spans)
|
||||
|
||||
return Y, backprop_windows
|
||||
|
||||
|
||||
def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
|
||||
"""Construct a flat array that has the indices we want to extract from the
|
||||
source data. For instance, if we want the spans (5, 9), (8, 10) the
|
||||
indices will be [5, 6, 7, 8, 8, 9].
|
||||
"""
|
||||
spans, lengths = _ensure_cpu(spans, lengths)
|
||||
indices = []
|
||||
offset = 0
|
||||
for i, length in enumerate(lengths):
|
||||
spans_i = spans[i].dataXd + offset
|
||||
for j in range(spans_i.shape[0]):
|
||||
indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1]))
|
||||
offset += length
|
||||
return ops.flatten(indices)
|
||||
|
||||
|
||||
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:
|
||||
return (Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths))
|
|
@ -1,6 +1,7 @@
|
|||
from .entity_linker import * # noqa
|
||||
from .multi_task import * # noqa
|
||||
from .parser import * # noqa
|
||||
from .spancat import * # noqa
|
||||
from .tagger import * # noqa
|
||||
from .textcat import * # noqa
|
||||
from .tok2vec import * # noqa
|
||||
|
|
54
spacy/ml/models/spancat.py
Normal file
54
spacy/ml/models/spancat.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from typing import List, Tuple
|
||||
from thinc.api import Model, with_getitem, chain, list2ragged, Logistic
|
||||
from thinc.api import Maxout, Linear, concatenate, glorot_uniform_init
|
||||
from thinc.api import reduce_mean, reduce_max, reduce_first, reduce_last
|
||||
from thinc.types import Ragged, Floats2d
|
||||
|
||||
from ...util import registry
|
||||
from ...tokens import Doc
|
||||
from ..extract_spans import extract_spans
|
||||
|
||||
|
||||
@registry.layers.register("spacy.LinearLogistic.v1")
|
||||
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
|
||||
"""An output layer for multi-label classification. It uses a linear layer
|
||||
followed by a logistic activation.
|
||||
"""
|
||||
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
|
||||
|
||||
|
||||
@registry.layers.register("spacy.mean_max_reducer.v1")
|
||||
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
|
||||
"""Reduce sequences by concatenating their mean and max pooled vectors,
|
||||
and then combine the concatenated vectors with a hidden layer.
|
||||
"""
|
||||
return chain(
|
||||
concatenate(reduce_last(), reduce_first(), reduce_mean(), reduce_max()),
|
||||
Maxout(nO=hidden_size, normalize=True, dropout=0.0),
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.SpanCategorizer.v1")
|
||||
def build_spancat_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
reducer: Model[Ragged, Floats2d],
|
||||
scorer: Model[Floats2d, Floats2d],
|
||||
) -> Model[Tuple[List[Doc], Ragged], Floats2d]:
|
||||
"""Build a span categorizer model, given a token-to-vector model, a
|
||||
reducer model to map the sequence of vectors for each span down to a single
|
||||
vector, and a scorer model to map the vectors to probabilities.
|
||||
|
||||
tok2vec (Model[List[Doc], List[Floats2d]]): The tok2vec model.
|
||||
reducer (Model[Ragged, Floats2d]): The reducer model.
|
||||
scorer (Model[Floats2d, Floats2d]): The scorer model.
|
||||
"""
|
||||
model = chain(
|
||||
with_getitem(0, chain(tok2vec, list2ragged())),
|
||||
extract_spans(),
|
||||
reducer,
|
||||
scorer,
|
||||
)
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
model.set_ref("reducer", reducer)
|
||||
model.set_ref("scorer", scorer)
|
||||
return model
|
|
@ -11,6 +11,7 @@ from .senter import SentenceRecognizer
|
|||
from .sentencizer import Sentencizer
|
||||
from .tagger import Tagger
|
||||
from .textcat import TextCategorizer
|
||||
from .spancat import SpanCategorizer
|
||||
from .textcat_multilabel import MultiLabel_TextCategorizer
|
||||
from .tok2vec import Tok2Vec
|
||||
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
||||
|
@ -27,6 +28,7 @@ __all__ = [
|
|||
"Pipe",
|
||||
"SentenceRecognizer",
|
||||
"Sentencizer",
|
||||
"SpanCategorizer",
|
||||
"Tagger",
|
||||
"TextCategorizer",
|
||||
"Tok2Vec",
|
||||
|
|
411
spacy/pipeline/spancat.py
Normal file
411
spacy/pipeline/spancat.py
Normal file
|
@ -0,0 +1,411 @@
|
|||
import numpy
|
||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
from thinc.types import Ragged, Ints2d, Floats2d
|
||||
|
||||
from ..scorer import Scorer
|
||||
from ..language import Language
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..tokens import Doc, SpanGroup, Span
|
||||
from ..vocab import Vocab
|
||||
from ..training import Example, validate_examples
|
||||
from ..errors import Errors
|
||||
from ..util import registry
|
||||
|
||||
|
||||
spancat_default_config = """
|
||||
[model]
|
||||
@architectures = "spacy.SpanCategorizer.v1"
|
||||
scorer = {"@layers": "spacy.LinearLogistic.v1"}
|
||||
|
||||
[model.reducer]
|
||||
@layers = spacy.mean_max_reducer.v1
|
||||
hidden_size = 128
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.Tok2Vec.v1"
|
||||
|
||||
[model.tok2vec.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = 96
|
||||
rows = [5000, 2000, 1000, 1000]
|
||||
attrs = ["ORTH", "PREFIX", "SUFFIX", "SHAPE"]
|
||||
include_static_vectors = false
|
||||
|
||||
[model.tok2vec.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
width = ${model.tok2vec.embed.width}
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
depth = 4
|
||||
"""
|
||||
|
||||
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
|
||||
|
||||
|
||||
@registry.misc("ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
position."""
|
||||
|
||||
def ngram_suggester(docs: List[Doc], *, ops: Optional[Ops] = None) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
starts = ops.xp.arange(len(doc), dtype="i")
|
||||
starts = starts.reshape((-1, 1))
|
||||
length = 0
|
||||
for size in sizes:
|
||||
if size <= len(doc):
|
||||
starts_size = starts[:len(doc) - (size - 1)]
|
||||
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
|
||||
length += spans[-1].shape[0]
|
||||
if spans:
|
||||
assert spans[-1].ndim == 2, spans[-1].shape
|
||||
lengths.append(length)
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.xp.vstack(spans), ops.asarray(lengths, dtype="i"))
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0,0)), ops.asarray(lengths, dtype="i"))
|
||||
|
||||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
return ngram_suggester
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"spancat",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"spans_key": "sc",
|
||||
"max_positive": None,
|
||||
"model": DEFAULT_SPANCAT_MODEL,
|
||||
"suggester": {"@misc": "ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)
|
||||
def make_spancat(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Callable[[List[Doc]], Ragged],
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
threshold: float = 0.5,
|
||||
max_positive: Optional[int] = None,
|
||||
) -> "SpanCategorizer":
|
||||
"""Create a SpanCategorizer component. The span categorizer consists of two
|
||||
parts: a suggester function that proposes candidate spans, and a labeller
|
||||
model that predicts one or more labels for each span.
|
||||
|
||||
suggester (Callable[List[Doc], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
|
||||
is given a list of documents and (start, end) indices representing
|
||||
candidate span offsets. The model predicts a probability for each category
|
||||
for each span.
|
||||
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
threshold (float): Minimum probability to consider a prediction positive.
|
||||
Spans with a positive prediction will be saved on the Doc. Defaults to
|
||||
0.5.
|
||||
max_positive (Optional[int]): Maximum number of labels to consider positive
|
||||
per span. Defaults to None, indicating no limit.
|
||||
"""
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
suggester=suggester,
|
||||
model=model,
|
||||
spans_key=spans_key,
|
||||
threshold=threshold,
|
||||
max_positive=max_positive,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
class SpanCategorizer(TrainablePipe):
|
||||
"""Pipeline component to label spans of text.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab: Vocab,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
suggester: Callable[[List[Doc]], Ragged],
|
||||
name: str = "spancat",
|
||||
*,
|
||||
spans_key: str = "spans",
|
||||
threshold: float = 0.5,
|
||||
max_positive: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize the span categorizer.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#init
|
||||
"""
|
||||
self.cfg = {
|
||||
"labels": [],
|
||||
"spans_key": spans_key,
|
||||
"threshold": threshold,
|
||||
"max_positive": max_positive,
|
||||
}
|
||||
self.vocab = vocab
|
||||
self.suggester = suggester
|
||||
self.model = model
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
"""
|
||||
return self.cfg["spans_key"]
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
"""Add a new label to the pipe.
|
||||
|
||||
label (str): The label to add.
|
||||
RETURNS (int): 0 if label is already present, otherwise 1.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#add_label
|
||||
"""
|
||||
if not isinstance(label, str):
|
||||
raise ValueError(Errors.E187)
|
||||
if label in self.labels:
|
||||
return 0
|
||||
self.cfg["labels"].append(label)
|
||||
self.vocab.strings.add(label)
|
||||
return 1
|
||||
|
||||
@property
|
||||
def labels(self) -> Tuple[str]:
|
||||
"""RETURNS (Tuple[str]): The labels currently added to the component.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#labels
|
||||
"""
|
||||
return tuple(self.cfg["labels"])
|
||||
|
||||
@property
|
||||
def label_data(self) -> List[str]:
|
||||
"""RETURNS (List[str]): Information about the component's labels.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#label_data
|
||||
"""
|
||||
return list(self.labels)
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
||||
docs (Iterable[Doc]): The documents to predict.
|
||||
RETURNS: The models prediction for each document.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#predict
|
||||
"""
|
||||
indices = self.suggester(docs, ops=self.model.ops)
|
||||
scores = self.model.predict((docs, indices))
|
||||
return (indices, scores)
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
|
||||
"""Modify a batch of Doc objects, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
scores: The scores to set, produced by SpanCategorizer.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#set_annotations
|
||||
"""
|
||||
labels = self.labels
|
||||
indices, scores = indices_scores
|
||||
offset = 0
|
||||
for i, doc in enumerate(docs):
|
||||
indices_i = indices[i].dataXd
|
||||
doc.spans[self.key] = self._make_span_group(
|
||||
doc, indices_i, scores[offset : offset + indices.lengths[i]], labels
|
||||
)
|
||||
offset += indices.lengths[i]
|
||||
|
||||
def update(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Learn from a batch of documents and gold-standard information,
|
||||
updating the pipe's model. Delegates to predict and get_loss.
|
||||
|
||||
examples (Iterable[Example]): A batch of Example objects.
|
||||
drop (float): The dropout rate.
|
||||
sgd (thinc.api.Optimizer): The optimizer.
|
||||
losses (Dict[str, float]): Optional record of the loss during training.
|
||||
Updated using the component name as the key.
|
||||
RETURNS (Dict[str, float]): The updated losses dictionary.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#update
|
||||
"""
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
validate_examples(examples, "SpanCategorizer.update")
|
||||
self._validate_categories(examples)
|
||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
docs = [eg.predicted for eg in examples]
|
||||
spans = self.suggester(docs, ops=self.model.ops)
|
||||
if spans.lengths.sum() == 0:
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, backprop_scores = self.model.begin_update((docs, spans))
|
||||
loss, d_scores = self.get_loss(examples, (spans, scores))
|
||||
backprop_scores(d_scores)
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
losses[self.name] += loss
|
||||
return losses
|
||||
|
||||
def get_loss(
|
||||
self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Ragged]
|
||||
) -> Tuple[float, float]:
|
||||
"""Find the loss and gradient of loss for the batch of documents and
|
||||
their predicted scores.
|
||||
|
||||
examples (Iterable[Examples]): The batch of examples.
|
||||
spans_scores: Scores representing the model's predictions.
|
||||
RETURNS (Tuple[float, float]): The loss and the gradient.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#get_loss
|
||||
"""
|
||||
spans, scores = spans_scores
|
||||
spans = Ragged(
|
||||
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
|
||||
)
|
||||
label_map = {label: i for i, label in enumerate(self.labels)}
|
||||
target = numpy.zeros(scores.shape, dtype=scores.dtype)
|
||||
offset = 0
|
||||
for i, eg in enumerate(examples):
|
||||
# Map (start, end) offset of spans to the row in the d_scores array,
|
||||
# so that we can adjust the gradient for predictions that were
|
||||
# in the gold standard.
|
||||
spans_index = {}
|
||||
spans_i = spans[i].dataXd
|
||||
for j in range(spans.lengths[i]):
|
||||
start = int(spans_i[j, 0])
|
||||
end = int(spans_i[j, 1])
|
||||
spans_index[(start, end)] = offset + j
|
||||
for gold_span in self._get_aligned_spans(eg):
|
||||
key = (gold_span.start, gold_span.end)
|
||||
if key in spans_index:
|
||||
row = spans_index[key]
|
||||
k = label_map[gold_span.label_]
|
||||
target[row, k] = 1.0
|
||||
# The target is a flat array for all docs. Track the position
|
||||
# we're at within the flat array.
|
||||
offset += spans.lengths[i]
|
||||
target = self.model.ops.asarray(target, dtype="f")
|
||||
# The target will have the values 0 (for untrue predictions) or 1
|
||||
# (for true predictions).
|
||||
# The scores should be in the range [0, 1].
|
||||
# If the prediction is 0.9 and it's true, the gradient
|
||||
# will be -0.1 (0.9 - 1.0).
|
||||
# If the prediction is 0.9 and it's false, the gradient will be
|
||||
# 0.9 (0.9 - 0.0)
|
||||
d_scores = scores - target
|
||||
loss = float((d_scores ** 2).sum())
|
||||
return loss, d_scores
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Language = None,
|
||||
labels: Optional[Dict] = None,
|
||||
) -> None:
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
||||
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||
returns a representative sample of gold-standard Example objects.
|
||||
nlp (Language): The current nlp object the component is part of.
|
||||
labels: The labels to add to the component, typically generated by the
|
||||
`init labels` command. If no labels are provided, the get_examples
|
||||
callback is used to extract the labels from the data.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#initialize
|
||||
"""
|
||||
subbatch = []
|
||||
if labels is not None:
|
||||
for label in labels:
|
||||
self.add_label(label)
|
||||
for eg in get_examples():
|
||||
if labels is None:
|
||||
for span in eg.reference.spans[self.key]:
|
||||
self.add_label(span.label_)
|
||||
if len(subbatch) < 10:
|
||||
subbatch.append(eg)
|
||||
self._require_labels()
|
||||
if subbatch:
|
||||
docs = [eg.x for eg in subbatch]
|
||||
spans = self.suggester(docs)
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels))
|
||||
self.model.initialize(X=(docs, spans), Y=Y)
|
||||
else:
|
||||
self.model.initialize()
|
||||
|
||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
"""Score a batch of examples.
|
||||
|
||||
examples (Iterable[Example]): The examples to score.
|
||||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizer#score
|
||||
"""
|
||||
validate_examples(examples, "SpanCategorizer.score")
|
||||
self._validate_categories(examples)
|
||||
kwargs = dict(kwargs)
|
||||
attr_prefix = "spans_"
|
||||
kwargs.setdefault("attr", f"{attr_prefix}{self.key}")
|
||||
kwargs.setdefault("labels", self.labels)
|
||||
kwargs.setdefault("multi_label", True)
|
||||
kwargs.setdefault("threshold", self.cfg["threshold"])
|
||||
kwargs.setdefault(
|
||||
"getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
|
||||
)
|
||||
kwargs.setdefault("has_annotation", lambda doc: self.key in doc.spans)
|
||||
return Scorer.score_spans(examples, **kwargs)
|
||||
|
||||
def _validate_categories(self, examples):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def _get_aligned_spans(self, eg: Example):
|
||||
return eg.get_aligned_spans_y2x(eg.reference.spans.get(self.key, []))
|
||||
|
||||
def _make_span_group(
|
||||
self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
|
||||
) -> SpanGroup:
|
||||
spans = SpanGroup(doc, name=self.key)
|
||||
max_positive = self.cfg["max_positive"]
|
||||
threshold = self.cfg["threshold"]
|
||||
for i in range(indices.shape[0]):
|
||||
start = int(indices[i, 0])
|
||||
end = int(indices[i, 1])
|
||||
positives = []
|
||||
for j, score in enumerate(scores[i]):
|
||||
if score >= threshold:
|
||||
positives.append((score, start, end, labels[j]))
|
||||
positives.sort(reverse=True)
|
||||
if max_positive:
|
||||
positives = positives[:max_positive]
|
||||
for score, start, end, label in positives:
|
||||
spans.append(Span(doc, start, end, label=label))
|
||||
return spans
|
|
@ -101,7 +101,8 @@ cdef class TrainablePipe(Pipe):
|
|||
|
||||
def update(self,
|
||||
examples: Iterable["Example"],
|
||||
*, drop: float=0.0,
|
||||
*,
|
||||
drop: float=0.0,
|
||||
sgd: Optimizer=None,
|
||||
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
|
||||
"""Learn from a batch of documents and gold-standard information,
|
||||
|
|
|
@ -353,6 +353,7 @@ def test_language_factories_invalid():
|
|||
([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"c": 0.2}, {"a": 0.0, "b": 0.0, "c": 1.0}),
|
||||
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5}),
|
||||
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0, "f": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5, "f": 0.0}),
|
||||
],
|
||||
)
|
||||
def test_language_factories_combine_score_weights(weights, override, expected):
|
||||
|
|
146
spacy/tests/pipeline/test_spancat.py
Normal file
146
spacy/tests/pipeline/test_spancat.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
from numpy.testing import assert_equal
|
||||
from spacy.language import Language
|
||||
from spacy.training import Example
|
||||
from spacy.util import fix_random_seed, registry
|
||||
|
||||
|
||||
SPAN_KEY = "labeled_spans"
|
||||
|
||||
TRAIN_DATA = [
|
||||
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
|
||||
(
|
||||
"I like London and Berlin.",
|
||||
{"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC")]}},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def make_get_examples(nlp):
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
|
||||
train_examples.append(eg)
|
||||
|
||||
def get_examples():
|
||||
return train_examples
|
||||
|
||||
return get_examples
|
||||
|
||||
|
||||
def test_simple_train():
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
get_examples = make_get_examples(nlp)
|
||||
nlp.initialize(get_examples)
|
||||
sgd = nlp.create_optimizer()
|
||||
assert len(spancat.labels) != 0
|
||||
for i in range(40):
|
||||
losses = {}
|
||||
nlp.update(list(get_examples()), losses=losses, drop=0.1, sgd=sgd)
|
||||
doc = nlp("I like London and Berlin.")
|
||||
assert doc.spans[spancat.key] == doc.spans[SPAN_KEY]
|
||||
assert len(doc.spans[spancat.key]) == 2
|
||||
assert doc.spans[spancat.key][0].text == "London"
|
||||
scores = nlp.evaluate(get_examples())
|
||||
assert f"spans_{SPAN_KEY}_f" in scores
|
||||
assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
|
||||
|
||||
|
||||
def test_ngram_suggester(en_tokenizer):
|
||||
# test different n-gram lengths
|
||||
for size in [1, 2, 3]:
|
||||
ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[size])
|
||||
docs = [
|
||||
en_tokenizer(text)
|
||||
for text in [
|
||||
"a",
|
||||
"a b",
|
||||
"a b c",
|
||||
"a b c d",
|
||||
"a b c d e",
|
||||
"a " * 100,
|
||||
]
|
||||
]
|
||||
ngrams = ngram_suggester(docs)
|
||||
# span sizes are correct
|
||||
for s in ngrams.data:
|
||||
assert s[1] - s[0] == size
|
||||
# spans are within docs
|
||||
offset = 0
|
||||
for i, doc in enumerate(docs):
|
||||
spans = ngrams.dataXd[offset : offset + ngrams.lengths[i]]
|
||||
spans_set = set()
|
||||
for span in spans:
|
||||
assert 0 <= span[0] < len(doc)
|
||||
assert 0 < span[1] <= len(doc)
|
||||
spans_set.add((span[0], span[1]))
|
||||
# spans are unique
|
||||
assert spans.shape[0] == len(spans_set)
|
||||
offset += ngrams.lengths[i]
|
||||
# the number of spans is correct
|
||||
assert_equal(
|
||||
ngrams.lengths,
|
||||
[max(0, len(doc) - (size - 1)) for doc in docs]
|
||||
)
|
||||
|
||||
# test 1-3-gram suggestions
|
||||
ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[1, 2, 3])
|
||||
docs = [
|
||||
en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"]
|
||||
]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [1, 3, 6, 9, 12])
|
||||
assert_equal(
|
||||
ngrams.data,
|
||||
[
|
||||
# doc 0
|
||||
[0, 1],
|
||||
# doc 1
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[0, 2],
|
||||
# doc 2
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[0, 3],
|
||||
# doc 3
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[2, 4],
|
||||
[0, 3],
|
||||
[1, 4],
|
||||
# doc 4
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[2, 4],
|
||||
[3, 5],
|
||||
[0, 3],
|
||||
[1, 4],
|
||||
[2, 5],
|
||||
],
|
||||
)
|
||||
|
||||
# test some empty docs
|
||||
ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[1])
|
||||
docs = [en_tokenizer(text) for text in ["", "a", ""]]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
|
||||
|
||||
# test all empty docs
|
||||
ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[1])
|
||||
docs = [en_tokenizer(text) for text in ["", "", ""]]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
|
|
@ -1,11 +1,14 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
||||
from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu
|
||||
from numpy.testing import assert_array_equal, assert_array_almost_equal
|
||||
import numpy
|
||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
|
||||
from spacy.ml.models import build_spancat_model
|
||||
from spacy.ml.staticvectors import StaticVectors
|
||||
from spacy.ml.extract_spans import extract_spans, _get_span_indices
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||
|
||||
|
@ -205,3 +208,63 @@ def test_empty_docs(model_func, kwargs):
|
|||
# Test backprop
|
||||
output, backprop = model.begin_update(docs)
|
||||
backprop(output)
|
||||
|
||||
|
||||
def test_init_extract_spans():
|
||||
model = extract_spans().initialize()
|
||||
|
||||
|
||||
def test_extract_spans_span_indices():
|
||||
model = extract_spans().initialize()
|
||||
spans = Ragged(
|
||||
model.ops.asarray([[0, 3], [2, 3], [5, 7]], dtype="i"),
|
||||
model.ops.asarray([2, 1], dtype="i"),
|
||||
)
|
||||
x_lengths = model.ops.asarray([5, 10], dtype="i")
|
||||
indices = _get_span_indices(model.ops, spans, x_lengths)
|
||||
assert list(indices) == [0, 1, 2, 2, 10, 11]
|
||||
|
||||
|
||||
def test_extract_spans_forward_backward():
|
||||
model = extract_spans().initialize()
|
||||
X = Ragged(model.ops.alloc2f(15, 4), model.ops.asarray([5, 10], dtype="i"))
|
||||
spans = Ragged(
|
||||
model.ops.asarray([[0, 3], [2, 3], [5, 7]], dtype="i"),
|
||||
model.ops.asarray([2, 1], dtype="i"),
|
||||
)
|
||||
Y, backprop = model.begin_update((X, spans))
|
||||
assert list(Y.lengths) == [3, 1, 2]
|
||||
assert Y.dataXd.shape == (6, 4)
|
||||
dX, spans2 = backprop(Y)
|
||||
assert spans2 is spans
|
||||
assert dX.dataXd.shape == X.dataXd.shape
|
||||
assert list(dX.lengths) == list(X.lengths)
|
||||
|
||||
|
||||
def test_spancat_model_init():
|
||||
model = build_spancat_model(
|
||||
build_Tok2Vec_model(**get_tok2vec_kwargs()), reduce_mean(), Logistic()
|
||||
)
|
||||
model.initialize()
|
||||
|
||||
|
||||
def test_spancat_model_forward_backward(nO=5):
|
||||
tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||
docs = get_docs()
|
||||
spans_list = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
spans_list.append(doc[:2])
|
||||
spans_list.append(doc[1:4])
|
||||
lengths.append(2)
|
||||
spans = Ragged(
|
||||
tok2vec.ops.asarray([[s.start, s.end] for s in spans_list], dtype="i"),
|
||||
tok2vec.ops.asarray(lengths, dtype="i"),
|
||||
)
|
||||
model = build_spancat_model(
|
||||
tok2vec, reduce_mean(), chain(Relu(nO=nO), Logistic())
|
||||
).initialize(X=(docs, spans))
|
||||
|
||||
Y, backprop = model((docs, spans), is_train=True)
|
||||
assert Y.shape == (spans.dataXd.shape[0], nO)
|
||||
backprop(Y)
|
||||
|
|
|
@ -1394,7 +1394,8 @@ def combine_score_weights(
|
|||
# We divide each weight by the total weight sum.
|
||||
# We first need to extract all None/null values for score weights that
|
||||
# shouldn't be shown in the table *or* be weighted
|
||||
result = {key: overrides.get(key, value) for w_dict in weights for (key, value) in w_dict.items()}
|
||||
result = {key: value for w_dict in weights for (key, value) in w_dict.items()}
|
||||
result.update(overrides)
|
||||
weight_sum = sum([v if v else 0.0 for v in result.values()])
|
||||
for key, value in result.items():
|
||||
if value and weight_sum > 0:
|
||||
|
|
|
@ -9,6 +9,7 @@ menu:
|
|||
- ['Parser & NER', 'parser']
|
||||
- ['Tagging', 'tagger']
|
||||
- ['Text Classification', 'textcat']
|
||||
- ['Span Classification', 'spancat']
|
||||
- ['Entity Linking', 'entitylinker']
|
||||
---
|
||||
|
||||
|
@ -736,6 +737,54 @@ Since v2, new labels can be added to this component, even after training.
|
|||
|
||||
</Accordion>
|
||||
|
||||
## Span classification architectures {#spancat source="spacy/ml/models/spancat.py"}
|
||||
|
||||
### spacy.SpanCategorizer.v1 {#SpanCategorizer}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.SpanCategorizer.v1"
|
||||
> scorer = {"@layers": "spacy.LinearLogistic.v1"}
|
||||
>
|
||||
> [model.reducer]
|
||||
> @layers = spacy.mean_max_reducer.v1"
|
||||
> hidden_size = 128
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> @architectures = "spacy.Tok2Vec.v1"
|
||||
>
|
||||
> [model.tok2vec.embed]
|
||||
> @architectures = "spacy.MultiHashEmbed.v1"
|
||||
> # ...
|
||||
>
|
||||
> [model.tok2vec.encode]
|
||||
> @architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
> # ...
|
||||
> ```
|
||||
|
||||
Build a span categorizer model to power a
|
||||
[`SpanCategorizer`](/api/spancategorizer) component, given a token-to-vector
|
||||
model, a reducer model to map the sequence of vectors for each span down to a
|
||||
single vector, and a scorer model to map the vectors to probabilities.
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | The token-to-vector model. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
| `reducer` | The reducer model. ~~Model[Ragged, Floats2d]~~ |
|
||||
| `scorer` | The scorer model. ~~Model[Floats2d, Floats2d]~~ |
|
||||
| **CREATES** | The model using the architecture. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
||||
|
||||
### spacy.mean_max_reducer.v1 {#mean_max_reducer}
|
||||
|
||||
Reduce sequences by concatenating their mean and max pooled vectors, and then
|
||||
combine the concatenated vectors with a hidden layer.
|
||||
|
||||
| Name | Description |
|
||||
| ------------- | ------------------------------------- |
|
||||
| `hidden_size` | The size of the hidden layer. ~~int~~ |
|
||||
|
||||
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
|
||||
|
||||
An [`EntityLinker`](/api/entitylinker) component disambiguates textual mentions
|
||||
|
|
453
website/docs/api/spancategorizer.md
Normal file
453
website/docs/api/spancategorizer.md
Normal file
|
@ -0,0 +1,453 @@
|
|||
---
|
||||
title: SpanCategorizer
|
||||
tag: class,experimental
|
||||
source: spacy/pipeline/spancat.py
|
||||
new: 3.1
|
||||
teaser: 'Pipeline component for labeling potentially overlapping spans of text'
|
||||
api_base_class: /api/pipe
|
||||
api_string_name: spancat
|
||||
api_trainable: true
|
||||
---
|
||||
|
||||
A span categorizer consists of two parts: a [suggester function](#suggesters)
|
||||
that proposes candidate spans, which may or may not overlap, and a labeler model
|
||||
that predicts zero or more labels for each candidate.
|
||||
|
||||
## Config and implementation {#config}
|
||||
|
||||
The default config is defined by the pipeline component factory and describes
|
||||
how the component should be configured. You can override its settings via the
|
||||
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
|
||||
[`config.cfg` for training](/usage/training#config). See the
|
||||
[model architectures](/api/architectures) documentation for details on the
|
||||
architectures and their arguments and hyperparameters.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.pipeline.spancat import DEFAULT_SPANCAT_MODEL
|
||||
> config = {
|
||||
> "threshold": 0.5,
|
||||
> "spans_key": "labeled_spans",
|
||||
> "max_positive": None,
|
||||
> "model": DEFAULT_SPANCAT_MODEL,
|
||||
> "suggester": {"@misc": "ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
> }
|
||||
> nlp.add_pipe("spancat", config=config)
|
||||
> ```
|
||||
|
||||
| Setting | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[List[Doc], Ragged]~~ |
|
||||
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
||||
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
|
||||
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
|
||||
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
|
||||
|
||||
```python
|
||||
%%GITHUB_SPACY/spacy/pipeline/spancat.py
|
||||
```
|
||||
|
||||
## SpanCategorizer.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> # Construction via add_pipe with default model
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
>
|
||||
> # Construction via add_pipe with custom model
|
||||
> config = {"model": {"@architectures": "my_spancat"}}
|
||||
> parser = nlp.add_pipe("spancat", config=config)
|
||||
>
|
||||
> # Construction from class
|
||||
> from spacy.pipeline import SpanCategorizer
|
||||
> spancat = SpanCategorizer(nlp.vocab, model, suggester)
|
||||
> ```
|
||||
|
||||
Create a new pipeline instance. In your application, you would normally use a
|
||||
shortcut for this and instantiate the component using its string name and
|
||||
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
|
||||
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[List[Doc], Ragged]~~ |
|
||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
|
||||
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
|
||||
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
|
||||
|
||||
## SpanCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
Apply the pipe to one document. The document is modified in place, and returned.
|
||||
This usually happens under the hood when the `nlp` object is called on a text
|
||||
and all pipeline components are applied to the `Doc` in order. Both
|
||||
[`__call__`](/api/spancategorizer#call) and [`pipe`](/api/spancategorizer#pipe)
|
||||
delegate to the [`predict`](/api/spancategorizer#predict) and
|
||||
[`set_annotations`](/api/spancategorizer#set_annotations) methods.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("This is a sentence.")
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> # This usually happens under the hood
|
||||
> processed = spancat(doc)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------- |
|
||||
| `doc` | The document to process. ~~Doc~~ |
|
||||
| **RETURNS** | The processed document. ~~Doc~~ |
|
||||
|
||||
## SpanCategorizer.pipe {#pipe tag="method"}
|
||||
|
||||
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||
when the `nlp` object is called on a text and all pipeline components are
|
||||
applied to the `Doc` in order. Both [`__call__`](/api/spancategorizer#call) and
|
||||
[`pipe`](/api/spancategorizer#pipe) delegate to the
|
||||
[`predict`](/api/spancategorizer#predict) and
|
||||
[`set_annotations`](/api/spancategorizer#set_annotations) methods.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> for doc in spancat.pipe(docs, batch_size=50):
|
||||
> pass
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------- |
|
||||
| `stream` | A stream of documents. ~~Iterable[Doc]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ |
|
||||
| **YIELDS** | The processed documents in order. ~~Doc~~ |
|
||||
|
||||
## SpanCategorizer.initialize {#initialize tag="method"}
|
||||
|
||||
Initialize the component for training. `get_examples` should be a function that
|
||||
returns an iterable of [`Example`](/api/example) objects. The data examples are
|
||||
used to **initialize the model** of the component and can either be the full
|
||||
training data or a representative sample. Initialization includes validating the
|
||||
network,
|
||||
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and
|
||||
setting up the label scheme based on the data. This method is typically called
|
||||
by [`Language.initialize`](/api/language#initialize) and lets you customize
|
||||
arguments it receives via the
|
||||
[`[initialize.components]`](/api/data-formats#config-initialize) block in the
|
||||
config.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.initialize(lambda: [], nlp=nlp)
|
||||
> ```
|
||||
>
|
||||
> ```ini
|
||||
> ### config.cfg
|
||||
> [initialize.components.spancat]
|
||||
>
|
||||
> [initialize.components.spancat.labels]
|
||||
> @readers = "spacy.read_labels.v1"
|
||||
> path = "corpus/labels/spancat.json
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |
|
||||
| `labels` | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ |
|
||||
|
||||
## SpanCategorizer.predict {#predict tag="method"}
|
||||
|
||||
Apply the component's model to a batch of [`Doc`](/api/doc) objects without
|
||||
modifying them.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> scores = spancat.predict([doc1, doc2])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------- |
|
||||
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
|
||||
| **RETURNS** | The model's prediction for each document. |
|
||||
|
||||
## SpanCategorizer.set_annotations {#set_annotations tag="method"}
|
||||
|
||||
Modify a batch of [`Doc`](/api/doc) objects using pre-computed scores.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> scores = spancat.predict(docs)
|
||||
> spancat.set_annotations(docs, scores)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------- | --------------------------------------------------------- |
|
||||
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
|
||||
| `scores` | The scores to set, produced by `SpanCategorizer.predict`. |
|
||||
|
||||
## SpanCategorizer.update {#update tag="method"}
|
||||
|
||||
Learn from a batch of [`Example`](/api/example) objects containing the
|
||||
predictions and gold-standard annotations, and update the component's model.
|
||||
Delegates to [`predict`](/api/spancategorizer#predict) and
|
||||
[`get_loss`](/api/spancategorizer#get_loss).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> optimizer = nlp.initialize()
|
||||
> losses = spancat.update(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `drop` | The dropout rate. ~~float~~ |
|
||||
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||
|
||||
## SpanCategorizer.get_loss {#get_loss tag="method"}
|
||||
|
||||
Find the loss and gradient of loss for the batch of documents and their
|
||||
predicted scores.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> scores = spancat.predict([eg.predicted for eg in examples])
|
||||
> loss, d_loss = spancat.get_loss(examples, scores)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | --------------------------------------------------------------------------- |
|
||||
| `examples` | The batch of examples. ~~Iterable[Example]~~ |
|
||||
| `scores` | Scores representing the model's predictions. |
|
||||
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
|
||||
|
||||
## SpanCategorizer.score {#score tag="method"}
|
||||
|
||||
Score a batch of examples.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> scores = spancat.score(examples)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
||||
| _keyword-only_ | |
|
||||
| **RETURNS** | The scores, produced by [`Scorer.score_spans`](/api/scorer#score_spans). ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
|
||||
## SpanCategorizer.create_optimizer {#create_optimizer tag="method"}
|
||||
|
||||
Create an optimizer for the pipeline component.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> optimizer = spancat.create_optimizer()
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ---------------------------- |
|
||||
| **RETURNS** | The optimizer. ~~Optimizer~~ |
|
||||
|
||||
## SpanCategorizer.use_params {#use_params tag="method, contextmanager"}
|
||||
|
||||
Modify the pipe's model to use the given parameter values.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> with spancat.use_params(optimizer.averages):
|
||||
> spancat.to_disk("/best_model")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------- | -------------------------------------------------- |
|
||||
| `params` | The parameter values to use in the model. ~~dict~~ |
|
||||
|
||||
## SpanCategorizer.add_label {#add_label tag="method"}
|
||||
|
||||
Add a new label to the pipe. Raises an error if the output dimension is already
|
||||
set, or if the model has already been fully [initialized](#initialize). Note
|
||||
that you don't have to call this method if you provide a **representative data
|
||||
sample** to the [`initialize`](#initialize) method. In this case, all labels
|
||||
found in the sample will be automatically added to the model, and the output
|
||||
dimension will be [inferred](/usage/layers-architectures#thinc-shape-inference)
|
||||
automatically.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.add_label("MY_LABEL")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ----------------------------------------------------------- |
|
||||
| `label` | The label to add. ~~str~~ |
|
||||
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
||||
|
||||
## SpanCategorizer.to_disk {#to_disk tag="method"}
|
||||
|
||||
Serialize the pipe to disk.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.to_disk("/path/to/spancat")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
|
||||
## SpanCategorizer.from_disk {#from_disk tag="method"}
|
||||
|
||||
Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.from_disk("/path/to/spancat")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ----------------------------------------------------------------------------------------------- |
|
||||
| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The modified `SpanCategorizer` object. ~~SpanCategorizer~~ |
|
||||
|
||||
## SpanCategorizer.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat_bytes = spancat.to_bytes()
|
||||
> ```
|
||||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The serialized form of the `SpanCategorizer` object. ~~bytes~~ |
|
||||
|
||||
## SpanCategorizer.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
Load the pipe from a bytestring. Modifies the object in place and returns it.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat_bytes = spancat.to_bytes()
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.from_bytes(spancat_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||
| `bytes_data` | The data to load from. ~~bytes~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The `SpanCategorizer` object. ~~SpanCategorizer~~ |
|
||||
|
||||
## SpanCategorizer.labels {#labels tag="property"}
|
||||
|
||||
The labels currently added to the component.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat.add_label("MY_LABEL")
|
||||
> assert "MY_LABEL" in spancat.labels
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------ |
|
||||
| **RETURNS** | The labels added to the component. ~~Tuple[str, ...]~~ |
|
||||
|
||||
## SpanCategorizer.label_data {#label_data tag="property"}
|
||||
|
||||
The labels currently added to the component and their internal meta information.
|
||||
This is the data generated by [`init labels`](/api/cli#init-labels) and used by
|
||||
[`SpanCategorizer.initialize`](/api/spancategorizer#initialize) to initialize
|
||||
the model with a pre-defined label set.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> labels = spancat.label_data
|
||||
> spancat.initialize(lambda: [], nlp=nlp, labels=labels)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ---------------------------------------------------------- |
|
||||
| **RETURNS** | The label data added to the component. ~~Tuple[str, ...]~~ |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
||||
During serialization, spaCy will export several data fields used to restore
|
||||
different aspects of the object. If needed, you can exclude them from
|
||||
serialization by passing in the string names via the `exclude` argument.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> data = spancat.to_disk("/path", exclude=["vocab"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------- | -------------------------------------------------------------- |
|
||||
| `vocab` | The shared [`Vocab`](/api/vocab). |
|
||||
| `cfg` | The config file. You usually don't want to exclude this. |
|
||||
| `model` | The binary model data. You usually don't want to exclude this. |
|
||||
|
||||
## Suggesters {#suggesters tag="registered functions" source="spacy/pipeline/spancat.py"}
|
||||
|
||||
### spacy.ngram_suggester.v1 {#ngram_suggester}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [components.spancat.suggester]
|
||||
> @misc = "spacy.ngram_suggester.v1"
|
||||
> sizes = [1, 2, 3]
|
||||
> ```
|
||||
|
||||
Suggest all spans of the given lengths. Spans are returned as a ragged array of
|
||||
integers. The array has two columns, indicating the start and end position.
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| `sizes` | The phrase lengths to suggest. For example, `[1, 2]` will suggest phrases consisting of 1 or 2 tokens. ~~List[int]~~ |
|
||||
| **CREATES** | The suggester function. ~~Callable[[List[Doc]], Ragged]~~ |
|
|
@ -94,6 +94,7 @@
|
|||
{ "text": "Morphologizer", "url": "/api/morphologizer" },
|
||||
{ "text": "SentenceRecognizer", "url": "/api/sentencerecognizer" },
|
||||
{ "text": "Sentencizer", "url": "/api/sentencizer" },
|
||||
{ "text": "SpanCategorizer", "url": "/api/spancategorizer" },
|
||||
{ "text": "Tagger", "url": "/api/tagger" },
|
||||
{ "text": "TextCategorizer", "url": "/api/textcategorizer" },
|
||||
{ "text": "Tok2Vec", "url": "/api/tok2vec" },
|
||||
|
|
Loading…
Reference in New Issue
Block a user