Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25)
SpanFinder into spaCy from experimental (#12507)
* span finder integrated into spacy from experimental
* black
* isort
* black
* default spans_key constant
* black
* Update spacy/pipeline/spancat.py (Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>)
* rename
* rename
* max_length and min_length as Optional[int] and strict checking
* black
* mypy fix for integer type infinity
* revert line order
* implement all comparison operators for inf int
* avoid two for loops over all docs by not precomputing
* interleave thresholding with span creation
* black
* revert to not interleaving (realized it's faster)
* black
* Update spacy/errors.py (Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>)
* update docstring
* enforce that the gold and predicted documents have the same text
* new error for ensuring reference and predicted texts are the same
* remove todo
* adjust test
* black
* handle misaligned tokenization
* return correct variable
* failing overfit test
* only use a single spans_key like in spancat
* black
* remove debug lines
* typo
* remove comment
* remove near-duplicate redundant method
* use the 'spans_key' variable name everywhere
* Update spacy/pipeline/span_finder.py (Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>)
* flaky test fix suggestion, hand-set bias terms
* only test suggester and test result exhaustively
* make it clear that the span_finder_suggester is more general (not specific to span_finder)
* Update spacy/tests/pipeline/test_span_finder.py (Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>)
* Apply suggestions from code review
* remove question comment
* move preset_spans_suggester test to spancat tests
* Add docs and unify default configs for spancat and span finder
* Add `allow_overlap=True` to span finder scorer
* Fix offset bug in set_annotations
* Ignore labels in span finder scorer
* Format
* Add span_finder to quickstart template
* Move settings to self.cfg, store min/max unset as None
* Remove debugging
* Update docstrings and docs
* Update spacy/pipeline/span_finder.py (Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>)
* Fix imports

---------

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in: parent c3c064ace4 → commit c003aac29a
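In short, the commit adds a trainable `span_finder` component that proposes unlabeled, potentially overlapping spans. A minimal usage sketch (hedged: assumes spaCy ≥ 3.6, where this change shipped; `"sc"` is the `DEFAULT_SPANS_KEY` introduced below):

```python
import spacy

nlp = spacy.blank("en")
# span_finder proposes unlabeled, possibly overlapping spans under
# doc.spans["sc"]; spancat can then label them via the new
# "spacy.preset_spans_suggester.v1" suggester (both shown in the diff below).
nlp.add_pipe("span_finder")
nlp.initialize()
doc = nlp("I like London and Berlin.")
print(doc.spans["sc"])  # untrained model here, so the spans are not meaningful yet
```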
spacy/cli/templates/quickstart_training.jinja

@@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and
 can help generate the best possible configuration, given a user's requirements. #}
 {%- set use_transformer = hardware != "cpu" and transformer_data -%}
 {%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
-{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
+{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "span_finder", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
 [paths]
 train = null
 dev = null

@@ -28,7 +28,7 @@ lang = "{{ lang }}"
 tok2vec/transformer. #}
 {%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
 {%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
-{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
+{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "span_finder" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
 {%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
 {%- else -%}
 {%- set full_pipeline = components -%}

@@ -127,6 +127,30 @@ grad_factor = 1.0
 @layers = "reduce_mean.v1"
 {% endif -%}

+{% if "span_finder" in components -%}
+[components.span_finder]
+factory = "span_finder"
+max_length = null
+min_length = null
+scorer = {"@scorers":"spacy.span_finder_scorer.v1"}
+spans_key = "sc"
+threshold = 0.5
+
+[components.span_finder.model]
+@architectures = "spacy.SpanFinder.v1"
+
+[components.span_finder.model.scorer]
+@layers = "spacy.LinearLogistic.v1"
+nO = 2
+
+[components.span_finder.model.tok2vec]
+@architectures = "spacy-transformers.TransformerListener.v1"
+grad_factor = 1.0
+
+[components.span_finder.model.tok2vec.pooling]
+@layers = "reduce_mean.v1"
+{% endif -%}
+
 {% if "spancat" in components -%}
 [components.spancat]
 factory = "spancat"

@@ -392,6 +416,27 @@ nO = null
 width = ${components.tok2vec.model.encode.width}
 {% endif %}

+{% if "span_finder" in components %}
+[components.span_finder]
+factory = "span_finder"
+max_length = null
+min_length = null
+scorer = {"@scorers":"spacy.span_finder_scorer.v1"}
+spans_key = "sc"
+threshold = 0.5
+
+[components.span_finder.model]
+@architectures = "spacy.SpanFinder.v1"
+
+[components.span_finder.model.scorer]
+@layers = "spacy.LinearLogistic.v1"
+nO = 2
+
+[components.span_finder.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode.width}
+{% endif %}
+
 {% if "spancat" in components %}
 [components.spancat]
 factory = "spancat"

spacy/errors.py

@@ -973,6 +973,10 @@ class Errors(metaclass=ErrorsWithCodes):
     E1052 = ("Unable to copy spans: the character offsets for the span at "
              "index {i} in the span group do not align with the tokenization "
              "in the target doc.")
+    E1053 = ("Both 'min_length' and 'max_length' should be larger than 0, but found"
+             " 'min_length': {min_length}, 'max_length': {max_length}")
+    E1054 = ("The text, including whitespace, must match between reference and "
+             "predicted docs when training {component}.")


 # Deprecated model shortcuts, only used in errors and warnings

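A quick illustration of where the new E1053 message surfaces (sketch; mirrors the validation in `SpanFinder.__init__` further down, and the behavior exercised by `test_set_annotations_span_lengths` in the new test file):

```python
import spacy

nlp = spacy.blank("en")
try:
    # Length bounds must be >= 1, so this raises ValueError with E1053
    # from SpanFinder.__init__ during component construction.
    nlp.add_pipe("span_finder", config={"min_length": 0})
except ValueError as err:
    print(err)
```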
spacy/ml/models/__init__.py

@@ -1,6 +1,7 @@
 from .entity_linker import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa
+from .span_finder import *  # noqa
 from .spancat import *  # noqa
 from .tagger import *  # noqa
 from .textcat import *  # noqa

spacy/ml/models/span_finder.py (new file, 42 lines)

from typing import Callable, List, Tuple

from thinc.api import Model, chain, with_array
from thinc.types import Floats1d, Floats2d

from ...tokens import Doc
from ...util import registry

InT = List[Doc]
OutT = Floats2d


@registry.architectures("spacy.SpanFinder.v1")
def build_finder_model(
    tok2vec: Model[InT, List[Floats2d]], scorer: Model[OutT, OutT]
) -> Model[InT, OutT]:
    logistic_layer: Model[List[Floats2d], List[Floats2d]] = with_array(scorer)
    model: Model[InT, OutT] = chain(tok2vec, logistic_layer, flattener())
    model.set_ref("tok2vec", tok2vec)
    model.set_ref("scorer", scorer)
    model.set_ref("logistic_layer", logistic_layer)

    return model


def flattener() -> Model[List[Floats2d], Floats2d]:
    """Flattens the input to a 1-dimensional list of scores"""

    def forward(
        model: Model[Floats1d, Floats1d], X: List[Floats2d], is_train: bool
    ) -> Tuple[Floats2d, Callable[[Floats2d], List[Floats2d]]]:
        lens = model.ops.asarray1i([len(doc) for doc in X])
        Y = model.ops.flatten(X)

        def backprop(dY: Floats2d) -> List[Floats2d]:
            return model.ops.unflatten(dY, lens)

        return Y, backprop

    return Model("Flattener", forward=forward)
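The `flattener` above concatenates the per-doc token score arrays into one array and remembers the lengths so gradients can be split back out per doc. A standalone numpy illustration of just the shape logic (not the thinc ops themselves):

```python
import numpy

# Two "docs" with 4 and 7 tokens, two scores (start/end) per token:
doc_scores = [numpy.ones((4, 2)), numpy.ones((7, 2))]
lens = [len(s) for s in doc_scores]                 # [4, 7], kept for backprop
flat = numpy.concatenate(doc_scores)                # shape (11, 2), like model.ops.flatten
grads = numpy.split(flat, numpy.cumsum(lens)[:-1])  # undo it, like ops.unflatten
assert [g.shape for g in grads] == [(4, 2), (7, 2)]
```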
spacy/pipeline/__init__.py

@@ -2,21 +2,22 @@ from .attributeruler import AttributeRuler
 from .dep_parser import DependencyParser
 from .edit_tree_lemmatizer import EditTreeLemmatizer
 from .entity_linker import EntityLinker
-from .ner import EntityRecognizer
 from .entityruler import EntityRuler
+from .functions import merge_entities, merge_noun_chunks, merge_subtokens
 from .lemmatizer import Lemmatizer
 from .morphologizer import Morphologizer
+from .ner import EntityRecognizer
 from .pipe import Pipe
-from .trainable_pipe import TrainablePipe
-from .senter import SentenceRecognizer
 from .sentencizer import Sentencizer
+from .senter import SentenceRecognizer
+from .span_finder import SpanFinder
+from .span_ruler import SpanRuler
+from .spancat import SpanCategorizer
 from .tagger import Tagger
 from .textcat import TextCategorizer
-from .spancat import SpanCategorizer
-from .span_ruler import SpanRuler
 from .textcat_multilabel import MultiLabel_TextCategorizer
 from .tok2vec import Tok2Vec
-from .functions import merge_entities, merge_noun_chunks, merge_subtokens
+from .trainable_pipe import TrainablePipe

 __all__ = [
     "AttributeRuler",

@@ -31,6 +32,7 @@ __all__ = [
     "SentenceRecognizer",
     "Sentencizer",
     "SpanCategorizer",
+    "SpanFinder",
     "SpanRuler",
     "Tagger",
     "TextCategorizer",

spacy/pipeline/span_finder.py (new file, 336 lines)

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from thinc.api import Config, Model, Optimizer, set_dropout_rate
from thinc.types import Floats2d

from ..language import Language
from .trainable_pipe import TrainablePipe
from ..scorer import Scorer
from ..tokens import Doc, Span
from ..training import Example
from ..errors import Errors

from ..util import registry
from .spancat import DEFAULT_SPANS_KEY

span_finder_default_config = """
[model]
@architectures = "spacy.SpanFinder.v1"

[model.scorer]
@layers = "spacy.LinearLogistic.v1"
nO = 2

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = 96
rows = [5000, 1000, 2500, 1000]
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 4
"""

DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]


@Language.factory(
    "span_finder",
    assigns=["doc.spans"],
    default_config={
        "threshold": 0.5,
        "model": DEFAULT_SPAN_FINDER_MODEL,
        "spans_key": DEFAULT_SPANS_KEY,
        "max_length": None,
        "min_length": None,
        "scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
    },
    default_score_weights={
        f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
        f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
        f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
    },
)
def make_span_finder(
    nlp: Language,
    name: str,
    model: Model[Iterable[Doc], Floats2d],
    spans_key: str,
    threshold: float,
    max_length: Optional[int],
    min_length: Optional[int],
    scorer: Optional[Callable],
) -> "SpanFinder":
    """Create a SpanFinder component. The component predicts whether a token is
    the start or the end of a potential span.

    model (Model[List[Doc], Floats2d]): A model instance that
        is given a list of documents and predicts a probability for each token.
    spans_key (str): Key of the doc.spans dict to save the spans under. During
        initialization and training, the component will look for spans on the
        reference document under the same key.
    threshold (float): Minimum probability to consider a prediction positive.
    max_length (Optional[int]): Maximum length of the produced spans, defaults
        to None meaning unlimited length.
    min_length (Optional[int]): Minimum length of the produced spans, defaults
        to None meaning shortest span length is 1.
    scorer (Optional[Callable]): The scoring method. Defaults to
        Scorer.score_spans for the Doc.spans[spans_key] with overlapping
        spans allowed.
    """
    return SpanFinder(
        nlp,
        model=model,
        threshold=threshold,
        name=name,
        scorer=scorer,
        max_length=max_length,
        min_length=min_length,
        spans_key=spans_key,
    )


@registry.scorers("spacy.span_finder_scorer.v1")
def make_span_finder_scorer():
    return span_finder_score


def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
    kwargs = dict(kwargs)
    attr_prefix = "span_finder_"
    key = kwargs["spans_key"]
    kwargs.setdefault("attr", f"{attr_prefix}{key}")
    kwargs.setdefault(
        "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
    )
    kwargs.setdefault("has_annotation", lambda doc: key in doc.spans)
    kwargs.setdefault("allow_overlap", True)
    kwargs.setdefault("labeled", False)
    scores = Scorer.score_spans(examples, **kwargs)
    scores.pop(f"{kwargs['attr']}_per_type", None)
    return scores


def _char_indices(span: Span) -> Tuple[int, int]:
    start = span[0].idx
    end = span[-1].idx + len(span[-1])
    return start, end


class SpanFinder(TrainablePipe):
    """Pipeline that learns span boundaries.

    DOCS: https://spacy.io/api/spanfinder
    """

    def __init__(
        self,
        nlp: Language,
        model: Model[Iterable[Doc], Floats2d],
        name: str = "span_finder",
        *,
        spans_key: str = DEFAULT_SPANS_KEY,
        threshold: float = 0.5,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        scorer: Optional[Callable] = span_finder_score,
    ) -> None:
        """Initialize the span finder.
        model (thinc.api.Model): The Thinc Model powering the pipeline
            component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        threshold (float): Minimum probability to consider a prediction
            positive.
        scorer (Optional[Callable]): The scoring method.
        spans_key (str): Key of the doc.spans dict to save the spans under.
            During initialization and training, the component will look for
            spans on the reference document under the same key.
        max_length (Optional[int]): Maximum length of the produced spans,
            defaults to None meaning unlimited length.
        min_length (Optional[int]): Minimum length of the produced spans,
            defaults to None meaning shortest span length is 1.

        DOCS: https://spacy.io/api/spanfinder#init
        """
        self.vocab = nlp.vocab
        if (max_length is not None and max_length < 1) or (
            min_length is not None and min_length < 1
        ):
            raise ValueError(
                Errors.E1053.format(min_length=min_length, max_length=max_length)
            )
        self.model = model
        self.name = name
        self.scorer = scorer
        self.cfg: Dict[str, Any] = {
            "min_length": min_length,
            "max_length": max_length,
            "threshold": threshold,
            "spans_key": spans_key,
        }

    def predict(self, docs: Iterable[Doc]):
        """Apply the pipeline's model to a batch of docs, without modifying
        them.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The model's prediction for each document.

        DOCS: https://spacy.io/api/spanfinder#predict
        """
        scores = self.model.predict(docs)
        return scores

    def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None:
        """Modify a batch of Doc objects, using pre-computed scores.
        docs (Iterable[Doc]): The documents to modify.
        scores: The scores to set, produced by SpanFinder predict method.

        DOCS: https://spacy.io/api/spanfinder#set_annotations
        """
        offset = 0
        for i, doc in enumerate(docs):
            doc.spans[self.cfg["spans_key"]] = []
            starts = []
            ends = []
            doc_scores = scores[offset : offset + len(doc)]

            for token, token_score in zip(doc, doc_scores):
                if token_score[0] >= self.cfg["threshold"]:
                    starts.append(token.i)
                if token_score[1] >= self.cfg["threshold"]:
                    ends.append(token.i)

            for start in starts:
                for end in ends:
                    span_length = end + 1 - start
                    if span_length < 1:
                        continue
                    if (
                        self.cfg["min_length"] is None
                        or self.cfg["min_length"] <= span_length
                    ) and (
                        self.cfg["max_length"] is None
                        or span_length <= self.cfg["max_length"]
                    ):
                        doc.spans[self.cfg["spans_key"]].append(doc[start : end + 1])
            offset += len(doc)

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.
        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (Optional[thinc.api.Optimizer]): The optimizer.
        losses (Optional[Dict[str, float]]): Optional record of the loss during
            training. Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/spanfinder#update
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        predicted = [eg.predicted for eg in examples]
        set_dropout_rate(self.model, drop)
        scores, backprop_scores = self.model.begin_update(predicted)
        loss, d_scores = self.get_loss(examples, scores)
        backprop_scores(d_scores)
        if sgd is not None:
            self.finish_update(sgd)
        losses[self.name] += loss
        return losses

    def get_loss(self, examples, scores) -> Tuple[float, Floats2d]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.
        examples (Iterable[Examples]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, Floats2d]): The loss and the gradient.

        DOCS: https://spacy.io/api/spanfinder#get_loss
        """
        truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
        d_scores = scores - self.model.ops.asarray2f(truths)
        d_scores *= masks
        loss = float((d_scores**2).sum())
        return loss, d_scores

    def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
        """Align scores of the predictions to the references for calculating
        the loss.
        """
        truths = []
        masks = []
        for eg in examples:
            if eg.x.text != eg.y.text:
                raise ValueError(Errors.E1054.format(component="span_finder"))
            n_tokens = len(eg.predicted)
            truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
            mask = ops.xp.ones((n_tokens, 2), dtype="float32")
            if self.cfg["spans_key"] in eg.reference.spans:
                for span in eg.reference.spans[self.cfg["spans_key"]]:
                    ref_start_char, ref_end_char = _char_indices(span)
                    pred_span = eg.predicted.char_span(
                        ref_start_char, ref_end_char, alignment_mode="expand"
                    )
                    pred_start_char, pred_end_char = _char_indices(pred_span)
                    start_match = pred_start_char == ref_start_char
                    end_match = pred_end_char == ref_end_char
                    if start_match:
                        truth[pred_span[0].i, 0] = 1
                    else:
                        mask[pred_span[0].i, 0] = 0
                    if end_match:
                        truth[pred_span[-1].i, 1] = 1
                    else:
                        mask[pred_span[-1].i, 1] = 0
            truths.append(truth)
            masks.append(mask)
        truths = ops.xp.concatenate(truths, axis=0)
        masks = ops.xp.concatenate(masks, axis=0)
        return truths, masks

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the pipe for training, using a representative set
        of data examples.
        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Optional[Language]): The current nlp object the component is part
            of.

        DOCS: https://spacy.io/api/spanfinder#initialize
        """
        subbatch: List[Example] = []

        for eg in get_examples():
            if len(subbatch) < 10:
                subbatch.append(eg)

        if subbatch:
            docs = [eg.reference for eg in subbatch]
            Y, _ = self._get_aligned_truth_scores(subbatch, self.model.ops)
            self.model.initialize(X=docs, Y=Y)
        else:
            self.model.initialize()
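How `set_annotations` turns token-level decisions into spans: every predicted start index is paired with every predicted end index at or after it, then the optional length filters apply. A tiny self-contained rerun of that loop (hypothetical token indices; thresholding already applied):

```python
starts, ends = [0, 2], [2, 4]   # token indices whose scores crossed the threshold
min_length, max_length = None, None

spans = []
for start in starts:
    for end in ends:
        span_length = end + 1 - start
        if span_length < 1:
            continue  # skip ends before the start
        if (min_length is None or min_length <= span_length) and (
            max_length is None or span_length <= max_length
        ):
            spans.append((start, end + 1))

assert spans == [(0, 3), (0, 5), (2, 3), (2, 5)]  # overlap is allowed by design
```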
spacy/pipeline/spancat.py

@@ -1,22 +1,20 @@
-from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
 from dataclasses import dataclass
 from functools import partial
-from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
-from thinc.api import Optimizer
-from thinc.types import Ragged, Ints2d, Floats2d
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast

 import numpy
+from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
+from thinc.types import Floats2d, Ints1d, Ints2d, Ragged

 from ..compat import Protocol, runtime_checkable
-from ..scorer import Scorer
-from ..language import Language
-from .trainable_pipe import TrainablePipe
-from ..tokens import Doc, SpanGroup, Span
-from ..vocab import Vocab
-from ..training import Example, validate_examples
 from ..errors import Errors
+from ..language import Language
+from ..scorer import Scorer
+from ..tokens import Doc, Span, SpanGroup
+from ..training import Example, validate_examples
 from ..util import registry
+from ..vocab import Vocab
+from .trainable_pipe import TrainablePipe

 spancat_default_config = """
 [model]

@@ -33,8 +31,8 @@ hidden_size = 128
 [model.tok2vec.embed]
 @architectures = "spacy.MultiHashEmbed.v2"
 width = 96
-rows = [5000, 2000, 1000, 1000]
-attrs = ["ORTH", "PREFIX", "SUFFIX", "SHAPE"]
+rows = [5000, 1000, 2500, 1000]
+attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
 include_static_vectors = false

 [model.tok2vec.encode]

@@ -71,6 +69,7 @@ maxout_pieces = 3
 depth = 4
 """

+DEFAULT_SPANS_KEY = "sc"
 DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
 DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(
     spancat_singlelabel_default_config

@@ -112,6 +111,29 @@ def ngram_suggester(
     return output


+def preset_spans_suggester(
+    docs: Iterable[Doc], spans_key: str, *, ops: Optional[Ops] = None
+) -> Ragged:
+    if ops is None:
+        ops = get_current_ops()
+    spans = []
+    lengths = []
+    for doc in docs:
+        length = 0
+        if doc.spans[spans_key]:
+            for span in doc.spans[spans_key]:
+                spans.append([span.start, span.end])
+                length += 1
+
+        lengths.append(length)
+    lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
+    if len(spans) > 0:
+        output = Ragged(ops.asarray(spans, dtype="i"), lengths_array)
+    else:
+        output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
+    return output
+
+
 @registry.misc("spacy.ngram_suggester.v1")
 def build_ngram_suggester(sizes: List[int]) -> Suggester:
     """Suggest all spans of the given lengths. Spans are returned as a ragged

@@ -130,12 +152,20 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
     return build_ngram_suggester(sizes)


+@registry.misc("spacy.preset_spans_suggester.v1")
+def build_preset_spans_suggester(spans_key: str) -> Suggester:
+    """Suggest all spans that are already stored in doc.spans[spans_key].
+    This is useful when an upstream component is used to set the spans
+    on the Doc such as a SpanRuler or SpanFinder."""
+    return partial(preset_spans_suggester, spans_key=spans_key)
+
+
 @Language.factory(
     "spancat",
     assigns=["doc.spans"],
     default_config={
         "threshold": 0.5,
-        "spans_key": "sc",
+        "spans_key": DEFAULT_SPANS_KEY,
         "max_positive": None,
         "model": DEFAULT_SPANCAT_MODEL,
         "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},

@@ -199,7 +229,7 @@ def make_spancat(
     "spancat_singlelabel",
     assigns=["doc.spans"],
     default_config={
-        "spans_key": "sc",
+        "spans_key": DEFAULT_SPANS_KEY,
         "model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
         "negative_weight": 1.0,
         "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},

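A hedged sketch of the chaining that the new `preset_spans_suggester` enables (component order and config keys as in the code above; this only shows the wiring, training and labels are omitted):

```python
import spacy

nlp = spacy.blank("en")
# span_finder writes candidate spans to doc.spans["sc"]; spancat then
# classifies exactly those candidates instead of enumerating n-grams.
nlp.add_pipe("span_finder", config={"spans_key": "sc"})
nlp.add_pipe(
    "spancat",
    config={
        "spans_key": "sc",
        "suggester": {"@misc": "spacy.preset_spans_suggester.v1", "spans_key": "sc"},
    },
)
```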
spacy/tests/pipeline/test_span_finder.py (new file, 242 lines)

import pytest
from thinc.api import Config

from spacy.language import Language
from spacy.lang.en import English
from spacy.pipeline.span_finder import span_finder_default_config
from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
from spacy.util import registry
from spacy.util import fix_random_seed, make_tempdir


SPANS_KEY = "pytest"
TRAIN_DATA = [
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    (
        "I like London and Berlin.",
        {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}},
    ),
]

TRAIN_DATA_OVERLAPPING = [
    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
    (
        "I like London and Berlin",
        {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
    ),
    ("", {"spans": {SPANS_KEY: []}}),
]


def make_examples(nlp, data=TRAIN_DATA):
    train_examples = []
    for t in data:
        eg = Example.from_dict(nlp.make_doc(t[0]), t[1])
        train_examples.append(eg)
    return train_examples


@pytest.mark.parametrize(
    "tokens_predicted, tokens_reference, reference_truths",
    [
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        (
            ["Mon.", "-", "J", "une", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (1, 0), (0, 1), (0, 0)],
        ),
        (
            ["Mon", ".", "-", "June", "16"],
            ["Mon.", "-", "June", "1", "6"],
            [(0, 0), (0, 0), (0, 0), (1, 1), (0, 0)],
        ),
        (
            ["Mon.", "-J", "un", "e 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (0, 0), (0, 0), (0, 0)],
        ),
        pytest.param(
            ["Mon.-June", "16"],
            ["Mon.", "-", "June", "16"],
            [(0, 1), (0, 0)],
        ),
        pytest.param(
            ["Mon.-", "June", "16"],
            ["Mon.", "-", "J", "une", "16"],
            [(0, 0), (1, 1), (0, 0)],
        ),
        pytest.param(
            ["Mon.-", "June 16"],
            ["Mon.", "-", "June", "16"],
            [(0, 0), (1, 0)],
        ),
    ],
)
def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_truths):
    nlp = Language()
    predicted = Doc(
        nlp.vocab, words=tokens_predicted, spaces=[False] * len(tokens_predicted)
    )
    reference = Doc(
        nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
    )
    example = Example(predicted, reference)
    example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)]
    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()
    ops = span_finder.model.ops
    if predicted.text != reference.text:
        with pytest.raises(
            ValueError, match="must match between reference and predicted"
        ):
            span_finder._get_aligned_truth_scores([example], ops)
        return
    truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
    assert len(truth_scores) == len(tokens_predicted)
    ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))


def test_span_finder_model():
    nlp = Language()

    docs = [nlp("This is an example."), nlp("This is the second example.")]
    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]

    total_tokens = 0
    for doc in docs:
        total_tokens += len(doc)

    config = Config().from_str(span_finder_default_config).interpolate()
    model = registry.resolve(config)["model"]

    model.initialize(X=docs)
    predictions = model.predict(docs)

    assert len(predictions) == total_tokens
    assert len(predictions[0]) == 2


def test_span_finder_component():
    nlp = Language()

    docs = [nlp("This is an example."), nlp("This is the second example.")]
    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]

    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    nlp.initialize()
    docs = list(span_finder.pipe(docs))

    assert SPANS_KEY in docs[0].spans


@pytest.mark.parametrize(
    "min_length, max_length, span_count",
    [(0, 0, 0), (None, None, 8), (2, None, 6), (None, 1, 2), (2, 3, 2)],
)
def test_set_annotations_span_lengths(min_length, max_length, span_count):
    nlp = Language()
    doc = nlp("Me and Jenny goes together like peas and carrots.")
    if min_length == 0 and max_length == 0:
        with pytest.raises(ValueError, match="Both 'min_length' and 'max_length'"):
            span_finder = nlp.add_pipe(
                "span_finder",
                config={
                    "max_length": max_length,
                    "min_length": min_length,
                    "spans_key": SPANS_KEY,
                },
            )
        return
    span_finder = nlp.add_pipe(
        "span_finder",
        config={
            "max_length": max_length,
            "min_length": min_length,
            "spans_key": SPANS_KEY,
        },
    )
    nlp.initialize()
    # Starts  [Me, Jenny, peas]
    # Ends    [Jenny, peas, carrots]
    scores = [
        (1, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 0),
        (0, 0),
        (1, 1),
        (0, 0),
        (0, 1),
        (0, 0),
    ]
    span_finder.set_annotations([doc], scores)

    assert doc.spans[SPANS_KEY]
    assert len(doc.spans[SPANS_KEY]) == span_count

    # Assert below will fail when max_length is set to 0
    if max_length is None:
        max_length = float("inf")
    if min_length is None:
        min_length = 1

    assert all(min_length <= len(span) <= max_length for span in doc.spans[SPANS_KEY])


def test_overfitting_IO():
    # Simple test to try and quickly overfit the span_finder component - ensuring the ML models work correctly
    fix_random_seed(0)
    nlp = English()
    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
    train_examples = make_examples(nlp)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    assert span_finder.model.get_dim("nO") == 2

    for i in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["span_finder"] < 0.001

    # test the trained model
    test_text = "I like London and Berlin"
    doc = nlp(test_text)
    spans = doc.spans[SPANS_KEY]
    assert len(spans) == 3
    assert set([span.text for span in spans]) == {
        "London",
        "Berlin",
        "London and Berlin",
    }

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)
        spans2 = doc2.spans[SPANS_KEY]
        assert len(spans2) == 3
        assert set([span.text for span in spans2]) == {
            "London",
            "Berlin",
            "London and Berlin",
        }

    # Test scoring
    scores = nlp.evaluate(train_examples)
    assert f"span_finder_{SPANS_KEY}_f" in scores
    # It's not perfect 1.0 F1 because it's designed to overgenerate for now.
    assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
    assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0

    # also test that the spancat works for just a single entity in a sentence
    doc = nlp("London")
    assert len(doc.spans[SPANS_KEY]) == 1
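The parametrized alignment cases above hinge on `Doc.char_span(..., alignment_mode="expand")`: a reference boundary only becomes a gold start/end label when the expanded predicted span lands on the same characters. A minimal standalone check (hypothetical tokenization, mirroring the first test case):

```python
from spacy.tokens import Doc
from spacy.vocab import Vocab

# Predicted tokenization of "Mon.-June16", split as Mon | . | - | June | 16:
doc = Doc(Vocab(), words=["Mon", ".", "-", "June", "16"], spaces=[False] * 5)
# The reference span covers characters 5-9 ("June"); expansion keeps it
# intact, so both its start and end count as gold truth for token "June".
span = doc.char_span(5, 9, alignment_mode="expand")
assert span.text == "June"
```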
spacy/tests/pipeline/test_spancat.py

@@ -406,6 +406,21 @@ def test_ngram_sizes(en_tokenizer):
     assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])


+def test_preset_spans_suggester():
+    nlp = Language()
+    docs = [nlp("This is an example."), nlp("This is the second example.")]
+    docs[0].spans[SPAN_KEY] = [docs[0][3:4]]
+    docs[1].spans[SPAN_KEY] = [docs[1][0:4], docs[1][3:5]]
+    suggester = registry.misc.get("spacy.preset_spans_suggester.v1")(spans_key=SPAN_KEY)
+    candidates = suggester(docs)
+    assert type(candidates) == Ragged
+    assert len(candidates) == 2
+    assert list(candidates.dataXd[0]) == [3, 4]
+    assert list(candidates.dataXd[1]) == [0, 4]
+    assert list(candidates.dataXd[2]) == [3, 5]
+    assert list(candidates.lengths) == [1, 2]
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
     fix_random_seed(0)

@@ -428,7 +443,7 @@ def test_overfitting_IO():
     spans = doc.spans[SPAN_KEY]
     assert len(spans) == 2
     assert len(spans.attrs["scores"]) == 2
-    assert min(spans.attrs["scores"]) > 0.9
+    assert min(spans.attrs["scores"]) > 0.8
     assert set([span.text for span in spans]) == {"London", "Berlin"}
     assert set([span.label_ for span in spans]) == {"LOC"}

@@ -440,7 +455,7 @@ def test_overfitting_IO():
     spans2 = doc2.spans[SPAN_KEY]
     assert len(spans2) == 2
     assert len(spans2.attrs["scores"]) == 2
-    assert min(spans2.attrs["scores"]) > 0.9
+    assert min(spans2.attrs["scores"]) > 0.8
     assert set([span.text for span in spans2]) == {"London", "Berlin"}
     assert set([span.label_ for span in spans2]) == {"LOC"}

website/docs/api/spancategorizer.mdx

@@ -105,7 +105,7 @@ architectures and their arguments and hyperparameters.
 >
 > # Construction via add_pipe with custom model
 > config = {"model": {"@architectures": "my_spancat"}}
-> parser = nlp.add_pipe("spancat", config=config)
+> spancat = nlp.add_pipe("spancat", config=config)
 >
 > # Construction from class
 > from spacy.pipeline import SpanCategorizer

@@ -524,3 +524,22 @@ has two columns, indicating the start and end position.
 | `min_size`  | The minimal phrase lengths to suggest (inclusive). ~~[int]~~                  |
 | `max_size`  | The maximal phrase lengths to suggest (exclusive). ~~[int]~~                  |
 | **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~  |
+
+### spacy.preset_spans_suggester.v1 {id="preset_spans_suggester"}
+
+> #### Example Config
+>
+> ```ini
+> [components.spancat.suggester]
+> @misc = "spacy.preset_spans_suggester.v1"
+> spans_key = "my_spans"
+> ```
+
+Suggest all spans that are already stored in doc.spans[spans_key]. This is
+useful when an upstream component is used to set the spans on the Doc such as a
+[`SpanRuler`](/api/spanruler) or [`SpanFinder`](/api/spanfinder).
+
+| Name        | Description                                                                    |
+| ----------- | ------------------------------------------------------------------------------ |
+| `spans_key` | Key of [`Doc.spans`](/api/doc/#spans) that provides spans to suggest. ~~str~~  |
+| **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~   |

website/docs/api/spanfinder.mdx (new file, 372 lines)

---
title: SpanFinder
tag: class,experimental
source: spacy/pipeline/span_finder.py
version: 3.6
teaser:
  'Pipeline component for identifying potentially overlapping spans of text'
api_base_class: /api/pipe
api_string_name: span_finder
api_trainable: true
---

The span finder identifies potentially overlapping, unlabeled spans. It
identifies tokens that start or end spans and annotates unlabeled spans between
starts and ends, with optional filters for min and max span length. It is
intended for use in combination with a component like
[`SpanCategorizer`](/api/spancategorizer) that may further filter or label the
spans. Predicted spans will be saved in a [`SpanGroup`](/api/spangroup) on the
doc under `doc.spans[spans_key]`, where `spans_key` is a component config
setting.

## Assigned Attributes {id="assigned-attributes"}

Predictions will be saved to `Doc.spans[spans_key]` as a
[`SpanGroup`](/api/spangroup).

`spans_key` defaults to `"sc"`, but can be passed as a parameter. The
`span_finder` component will overwrite any existing spans under the spans key
`doc.spans[spans_key]`.

| Location               | Value                              |
| ---------------------- | ---------------------------------- |
| `Doc.spans[spans_key]` | The unlabeled spans. ~~SpanGroup~~ |

## Config and implementation {id="config"}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

> #### Example
>
> ```python
> from spacy.pipeline.span_finder import DEFAULT_SPAN_FINDER_MODEL
> config = {
>     "threshold": 0.5,
>     "spans_key": "my_spans",
>     "max_length": None,
>     "min_length": None,
>     "model": DEFAULT_SPAN_FINDER_MODEL,
> }
> nlp.add_pipe("span_finder", config=config)
> ```

| Setting      | Description |
| ------------ | ----------- |
| `model`      | A model instance that is given a list of documents and predicts a probability for each token. ~~Model[List[Doc], Floats2d]~~ |
| `spans_key`  | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
| `threshold`  | Minimum probability to consider a prediction positive. Defaults to `0.5`. ~~float~~ |
| `max_length` | Maximum length of the produced spans, defaults to `None` meaning unlimited length. ~~Optional[int]~~ |
| `min_length` | Minimum length of the produced spans, defaults to `None` meaning shortest span length is 1. ~~Optional[int]~~ |
| `scorer`     | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |

```python
%%GITHUB_SPACY/spacy/pipeline/span_finder.py
```

## SpanFinder.\_\_init\_\_ {id="init",tag="method"}

> #### Example
>
> ```python
> # Construction via add_pipe with default model
> span_finder = nlp.add_pipe("span_finder")
>
> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_span_finder"}}
> span_finder = nlp.add_pipe("span_finder", config=config)
>
> # Construction from class
> from spacy.pipeline import SpanFinder
> span_finder = SpanFinder(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).

| Name           | Description |
| -------------- | ----------- |
| `vocab`        | The shared vocabulary. ~~Vocab~~ |
| `model`        | A model instance that is given a list of documents and predicts a probability for each token. ~~Model[List[Doc], Floats2d]~~ |
| `name`         | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `spans_key`    | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
| `threshold`    | Minimum probability to consider a prediction positive. Defaults to `0.5`. ~~float~~ |
| `max_length`   | Maximum length of the produced spans, defaults to `None` meaning unlimited length. ~~Optional[int]~~ |
| `min_length`   | Minimum length of the produced spans, defaults to `None` meaning shortest span length is 1. ~~Optional[int]~~ |
| `scorer`       | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |

## SpanFinder.\_\_call\_\_ {id="call",tag="method"}

Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
and all pipeline components are applied to the `Doc` in order. Both
[`__call__`](/api/spanfinder#call) and [`pipe`](/api/spanfinder#pipe) delegate
to the [`predict`](/api/spanfinder#predict) and
[`set_annotations`](/api/spanfinder#set_annotations) methods.

> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> span_finder = nlp.add_pipe("span_finder")
> # This usually happens under the hood
> processed = span_finder(doc)
> ```

| Name        | Description                      |
| ----------- | -------------------------------- |
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## SpanFinder.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order. Both [`__call__`](/api/spanfinder#call) and
[`pipe`](/api/spanfinder#pipe) delegate to the
[`predict`](/api/spanfinder#predict) and
[`set_annotations`](/api/spanfinder#set_annotations) methods.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> for doc in span_finder.pipe(docs, batch_size=50):
>     pass
> ```

| Name           | Description                                                   |
| -------------- | ------------------------------------------------------------- |
| `stream`       | A stream of documents. ~~Iterable[Doc]~~                      |
| _keyword-only_ |                                                               |
| `batch_size`   | The number of documents to buffer. Defaults to `128`. ~~int~~ |
| **YIELDS**     | The processed documents in order. ~~Doc~~                     |

## SpanFinder.initialize {id="initialize",tag="method"}

Initialize the component for training. `get_examples` should be a function that
returns an iterable of [`Example`](/api/example) objects. **At least one example
should be supplied.** The data examples are used to **initialize the model** of
the component and can either be the full training data or a representative
sample. Initialization includes validating the network and
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation). This
method is typically called by [`Language.initialize`](/api/language#initialize)
and lets you customize arguments it receives via the
[`[initialize.components]`](/api/data-formats#config-initialize) block in the
config.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> span_finder.initialize(lambda: examples, nlp=nlp)
> ```

| Name           | Description |
| -------------- | ----------- |
| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. Must contain at least one `Example`. ~~Callable[[], Iterable[Example]]~~ |
| _keyword-only_ | |
| `nlp`          | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |

## SpanFinder.predict {id="predict",tag="method"}

Apply the component's model to a batch of [`Doc`](/api/doc) objects without
modifying them.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> scores = span_finder.predict([doc1, doc2])
> ```

| Name        | Description                                 |
| ----------- | ------------------------------------------- |
| `docs`      | The documents to predict. ~~Iterable[Doc]~~ |
| **RETURNS** | The model's prediction for each document.   |

## SpanFinder.set_annotations {id="set_annotations",tag="method"}

Modify a batch of [`Doc`](/api/doc) objects using pre-computed scores.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> scores = span_finder.predict(docs)
> span_finder.set_annotations(docs, scores)
> ```

| Name     | Description                                          |
| -------- | ---------------------------------------------------- |
| `docs`   | The documents to modify. ~~Iterable[Doc]~~           |
| `scores` | The scores to set, produced by `SpanFinder.predict`. |

## SpanFinder.update {id="update",tag="method"}

Learn from a batch of [`Example`](/api/example) objects containing the
predictions and gold-standard annotations, and update the component's model.
Delegates to [`predict`](/api/spanfinder#predict) and
[`get_loss`](/api/spanfinder#get_loss).

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> optimizer = nlp.initialize()
> losses = span_finder.update(examples, sgd=optimizer)
> ```

| Name           | Description |
| -------------- | ----------- |
| `examples`     | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop`         | The dropout rate. ~~float~~ |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses`       | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |

## SpanFinder.get_loss {id="get_loss",tag="method"}

Find the loss and gradient of loss for the batch of documents and their
predicted scores.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> scores = span_finder.predict([eg.predicted for eg in examples])
> loss, d_loss = span_finder.get_loss(examples, scores)
> ```

| Name           | Description |
| -------------- | ----------- |
| `examples`     | The batch of examples. ~~Iterable[Example]~~ |
| `spans_scores` | Scores representing the model's predictions. ~~Tuple[Ragged, Floats2d]~~ |
| **RETURNS**    | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, Floats2d]~~ |

## SpanFinder.create_optimizer {id="create_optimizer",tag="method"}

Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> optimizer = span_finder.create_optimizer()
> ```

| Name        | Description                  |
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |

## SpanFinder.use_params {id="use_params",tag="method, contextmanager"}

Modify the pipe's model to use the given parameter values.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> with span_finder.use_params(optimizer.averages):
>     span_finder.to_disk("/best_model")
> ```

| Name     | Description                                        |
| -------- | -------------------------------------------------- |
| `params` | The parameter values to use in the model. ~~dict~~ |

## SpanFinder.to_disk {id="to_disk",tag="method"}

Serialize the pipe to disk.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> span_finder.to_disk("/path/to/span_finder")
> ```

| Name           | Description |
| -------------- | ----------- |
| `path`         | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
| _keyword-only_ | |
| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |

## SpanFinder.from_disk {id="from_disk",tag="method"}

Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> span_finder.from_disk("/path/to/span_finder")
> ```

| Name           | Description |
| -------------- | ----------- |
| `path`         | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
| _keyword-only_ | |
| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS**    | The modified `SpanFinder` object. ~~SpanFinder~~ |

## SpanFinder.to_bytes {id="to_bytes",tag="method"}

> #### Example
>
> ```python
> span_finder = nlp.add_pipe("span_finder")
> span_finder_bytes = span_finder.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name           | Description |
| -------------- | ----------- |
| _keyword-only_ | |
| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS**    | The serialized form of the `SpanFinder` object. ~~bytes~~ |

## SpanFinder.from_bytes {id="from_bytes",tag="method"}

Load the pipe from a bytestring. Modifies the object in place and returns it.

> #### Example
>
> ```python
> span_finder_bytes = span_finder.to_bytes()
> span_finder = nlp.add_pipe("span_finder")
> span_finder.from_bytes(span_finder_bytes)
> ```

| Name           | Description |
| -------------- | ----------- |
| `bytes_data`   | The data to load from. ~~bytes~~ |
| _keyword-only_ | |
| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS**    | The `SpanFinder` object. ~~SpanFinder~~ |

## Serialization fields {id="serialization-fields"}

During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.

> #### Example
>
> ```python
> data = span_finder.to_disk("/path", exclude=["vocab"])
> ```

| Name    | Description                                                    |
| ------- | -------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab).                              |
| `cfg`   | The config file. You usually don't want to exclude this.       |
| `model` | The binary model data. You usually don't want to exclude this. |

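Pulling the API above together, a hedged end-to-end sketch (assumes spaCy ≥ 3.6; the training data and iteration count are made up for illustration):

```python
import spacy
from spacy.training import Example
from spacy.util import fix_random_seed

fix_random_seed(0)
nlp = spacy.blank("en")
nlp.add_pipe("span_finder")  # spans_key defaults to "sc"
examples = [
    Example.from_dict(
        nlp.make_doc("I like London and Berlin."),
        {"spans": {"sc": [(7, 13), (18, 24)]}},
    )
]
optimizer = nlp.initialize(get_examples=lambda: examples)
for _ in range(30):
    nlp.update(examples, sgd=optimizer, losses={})
scores = nlp.evaluate(examples)
print(scores.get("span_finder_sc_f"))  # F-score under the span_finder scorer
```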
website/meta/sidebar.json

@@ -106,6 +106,7 @@
       { "text": "SentenceRecognizer", "url": "/api/sentencerecognizer" },
       { "text": "Sentencizer", "url": "/api/sentencizer" },
       { "text": "SpanCategorizer", "url": "/api/spancategorizer" },
+      { "text": "SpanFinder", "url": "/api/spanfinder" },
       { "text": "SpanResolver", "url": "/api/span-resolver" },
       { "text": "SpanRuler", "url": "/api/spanruler" },
       { "text": "Tagger", "url": "/api/tagger" },