mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
merge
This commit is contained in:
commit
403fb95d56
|
@ -14,17 +14,15 @@ from .coref_util import add_dummy
|
|||
@registry.architectures("spacy.Coref.v1")
|
||||
def build_wl_coref_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
embedding_size: int = 20,
|
||||
distance_embedding_size: int = 20,
|
||||
hidden_size: int = 1024,
|
||||
depth: int = 1,
|
||||
dropout: float = 0.3,
|
||||
# pairs to keep per mention after rough scoring
|
||||
rough_candidates: int = 50,
|
||||
# TODO is this not a training loop setting?
|
||||
a_scoring_batch_size: int = 512,
|
||||
# span predictor embeddings
|
||||
sp_embedding_size: int = 64,
|
||||
antecedent_limit: int = 50,
|
||||
antecedent_batch_size: int = 512,
|
||||
):
|
||||
# TODO add model return types
|
||||
# TODO fix this
|
||||
try:
|
||||
dim = tok2vec.get_dim("nO")
|
||||
|
@ -36,12 +34,12 @@ def build_wl_coref_model(
|
|||
coref_scorer = PyTorchWrapper(
|
||||
CorefScorer(
|
||||
dim,
|
||||
embedding_size,
|
||||
distance_embedding_size,
|
||||
hidden_size,
|
||||
depth,
|
||||
dropout,
|
||||
rough_candidates,
|
||||
a_scoring_batch_size,
|
||||
antecedent_limit,
|
||||
antecedent_batch_size,
|
||||
),
|
||||
convert_inputs=convert_coref_scorer_inputs,
|
||||
convert_outputs=convert_coref_scorer_outputs,
|
||||
|
@ -100,7 +98,7 @@ class CorefScorer(torch.nn.Module):
|
|||
dist_emb_size: int,
|
||||
hidden_size: int,
|
||||
n_layers: int,
|
||||
dropout_rate: float,
|
||||
dropout: float,
|
||||
roughk: int,
|
||||
batch_size: int,
|
||||
):
|
||||
|
@ -110,31 +108,31 @@ class CorefScorer(torch.nn.Module):
|
|||
dist_emb_size: Size of the distance embeddings.
|
||||
hidden_size: Size of the coreference candidate embeddings.
|
||||
n_layers: Numbers of layers in the AnaphoricityScorer.
|
||||
dropout_rate: Dropout probability to apply across all modules.
|
||||
dropout: Dropout probability to apply across all modules.
|
||||
roughk: Number of candidates the RoughScorer returns.
|
||||
batch_size: Internal batch-size for the more expensive scorer.
|
||||
"""
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.batch_size = batch_size
|
||||
# Modules
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
|
||||
pair_emb = dim * 3 + self.pw.shape
|
||||
self.a_scorer = AnaphoricityScorer(
|
||||
pair_emb,
|
||||
hidden_size,
|
||||
n_layers,
|
||||
dropout_rate
|
||||
dropout
|
||||
)
|
||||
self.lstm = torch.nn.LSTM(
|
||||
input_size=dim,
|
||||
hidden_size=dim,
|
||||
batch_first=True,
|
||||
)
|
||||
self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
|
||||
self.rough_scorer = RoughScorer(dim, dropout, roughk)
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
|
||||
pair_emb = dim * 3 + self.pw.shape
|
||||
self.a_scorer = AnaphoricityScorer(
|
||||
pair_emb, hidden_size, n_layers, dropout_rate
|
||||
pair_emb, hidden_size, n_layers, dropout
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -191,18 +189,18 @@ class CorefScorer(torch.nn.Module):
|
|||
class AnaphoricityScorer(torch.nn.Module):
|
||||
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
|
||||
|
||||
def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
|
||||
def __init__(self, in_features: int, hidden_size, depth, dropout):
|
||||
super().__init__()
|
||||
hidden_size = hidden_size
|
||||
if not n_hidden_layers:
|
||||
if not depth:
|
||||
hidden_size = in_features
|
||||
layers = []
|
||||
for i in range(n_hidden_layers):
|
||||
for i in range(depth):
|
||||
layers.extend(
|
||||
[
|
||||
torch.nn.Linear(hidden_size if i else in_features, hidden_size),
|
||||
torch.nn.LeakyReLU(),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.Dropout(dropout),
|
||||
]
|
||||
)
|
||||
self.hidden = torch.nn.Sequential(*layers)
|
||||
|
@ -244,7 +242,7 @@ class AnaphoricityScorer(torch.nn.Module):
|
|||
def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: tensor of shape (batch_size x roughk x n_features
|
||||
returns: tensor of shape (batch_size x rough_k)
|
||||
returns: tensor of shape (batch_size x antecedent_limit)
|
||||
"""
|
||||
x = self.out(self.hidden(x))
|
||||
return x.squeeze(2)
|
||||
|
@ -290,11 +288,11 @@ class RoughScorer(torch.nn.Module):
|
|||
steps to reduce computational cost.
|
||||
"""
|
||||
|
||||
def __init__(self, features: int, dropout_rate: float, rough_k: float):
|
||||
def __init__(self, features: int, dropout: float, antecedent_limit: int):
|
||||
super().__init__()
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.bilinear = torch.nn.Linear(features, features)
|
||||
self.k = rough_k
|
||||
self.k = antecedent_limit
|
||||
|
||||
def forward(
|
||||
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
|
||||
|
@ -318,7 +316,7 @@ class RoughScorer(torch.nn.Module):
|
|||
|
||||
|
||||
class DistancePairwiseEncoder(torch.nn.Module):
|
||||
def __init__(self, embedding_size, dropout_rate):
|
||||
def __init__(self, distance_embedding_size, dropout):
|
||||
"""
|
||||
Takes the top_indices indicating, which is a ranked
|
||||
list for each word and its most likely corresponding
|
||||
|
@ -326,15 +324,15 @@ class DistancePairwiseEncoder(torch.nn.Module):
|
|||
up a distance embedding from a table, where the distance
|
||||
corresponds to the log-distance.
|
||||
|
||||
embedding_size: int,
|
||||
distance_embedding_size: int,
|
||||
Dimensionality of the distance-embeddings table.
|
||||
dropout_rate: float,
|
||||
dropout: float,
|
||||
Dropout probability.
|
||||
"""
|
||||
super().__init__()
|
||||
emb_size = embedding_size
|
||||
emb_size = distance_embedding_size
|
||||
self.distance_emb = torch.nn.Embedding(9, emb_size)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.shape = emb_size
|
||||
|
||||
def forward(
|
||||
|
|
|
@ -18,6 +18,7 @@ def build_span_predictor(
|
|||
dist_emb_size: int = 64,
|
||||
prefix: str = "coref_head_clusters"
|
||||
):
|
||||
# TODO add model return types
|
||||
# TODO fix this
|
||||
try:
|
||||
dim = tok2vec.get_dim("nO")
|
||||
|
|
|
@ -31,13 +31,12 @@ from ..coref_scorer import Evaluator, get_cluster_info, lea
|
|||
default_config = """
|
||||
[model]
|
||||
@architectures = "spacy.Coref.v1"
|
||||
embedding_size = 20
|
||||
distance_embedding_size = 20
|
||||
hidden_size = 1024
|
||||
depth = 1
|
||||
dropout = 0.3
|
||||
rough_candidates = 50
|
||||
a_scoring_batch_size = 512
|
||||
sp_embedding_size = 64
|
||||
antecedent_limit = 50
|
||||
antecedent_batch_size = 512
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
|
|
|
@ -922,3 +922,77 @@ A function that takes as input a [`KnowledgeBase`](/api/kb) and a
|
|||
plausible [`Candidate`](/api/kb/#candidate) objects. The default
|
||||
`CandidateGenerator` simply uses the text of a mention to find its potential
|
||||
aliases in the `KnowledgeBase`. Note that this function is case-dependent.
|
||||
|
||||
## Coreference Architectures
|
||||
|
||||
A [`CoreferenceResolver`](/api/coref) component identifies tokens that refer to
|
||||
the same entity. A [`SpanPredictor`](/api/span-predictor) component infers spans
|
||||
from single tokens. Together these components can be used to reproduce
|
||||
traditional coreference models. You can also omit the `SpanPredictor` for faster
|
||||
performance if working with only token-level clusters is acceptable.
|
||||
|
||||
### spacy.Coref.v1 {#Coref}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
>
|
||||
> [model]
|
||||
> @architectures = "spacy.Coref.v1"
|
||||
> distance_embedding_size = 20
|
||||
> dropout = 0.3
|
||||
> hidden_size = 1024
|
||||
> depth = 2
|
||||
> antecedent_limit = 50
|
||||
> antecedent_batch_size = 512
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> @architectures = "spacy-transformers.TransformerListener.v1"
|
||||
> grad_factor = 1.0
|
||||
> upstream = "transformer"
|
||||
> pooling = {"@layers":"reduce_mean.v1"}
|
||||
> ```
|
||||
|
||||
The `Coref` model architecture is a Thinc `Model`.
|
||||
|
||||
| Name | Description |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
|
||||
| `distance_embedding_size` | A representation of the distance between candidates. ~~int~~ |
|
||||
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
|
||||
| `hidden_size` | Size of the main internal layers. ~~int~~ |
|
||||
| `depth` | Depth of the internal network. ~~int~~ |
|
||||
| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
|
||||
| `antecedent_batch_size` | Internal batch size. ~~int~~ |
|
||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||
|
||||
### spacy.SpanPredictor.v1 {#SpanPredictor}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
>
|
||||
> [model]
|
||||
> @architectures = "spacy.SpanPredictor.v1"
|
||||
> hidden_size = 1024
|
||||
> dist_emb_size = 64
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> @architectures = "spacy-transformers.TransformerListener.v1"
|
||||
> grad_factor = 1.0
|
||||
> upstream = "transformer"
|
||||
> pooling = {"@layers":"reduce_mean.v1"}
|
||||
> ```
|
||||
|
||||
The `SpanPredictor` model architecture is a Thinc `Model`.
|
||||
|
||||
| Name | Description |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
|
||||
| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
|
||||
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
|
||||
| `hidden_size` | Size of the main internal layers. ~~int~~ |
|
||||
| `depth` | Depth of the internal network. ~~int~~ |
|
||||
| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
|
||||
| `antecedent_batch_size` | Internal batch size. ~~int~~ |
|
||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |
|
||||
|
|
|
@ -92,9 +92,9 @@ shortcut for this and instantiate the component using its string name and
|
|||
Apply the pipe to one document. The document is modified in place and returned.
|
||||
This usually happens under the hood when the `nlp` object is called on a text
|
||||
and all pipeline components are applied to the `Doc` in order. Both
|
||||
[`__call__`](/api/entitylinker#call) and [`pipe`](/api/entitylinker#pipe)
|
||||
delegate to the [`predict`](/api/entitylinker#predict) and
|
||||
[`set_annotations`](/api/entitylinker#set_annotations) methods.
|
||||
[`__call__`](/api/coref#call) and [`pipe`](/api/coref#pipe) delegate to the
|
||||
[`predict`](/api/coref#predict) and
|
||||
[`set_annotations`](/api/coref#set_annotations) methods.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -197,7 +197,7 @@ Modify a batch of documents, saving coreference clusters in `Doc.spans`.
|
|||
## CoreferenceResolver.update {#update tag="method"}
|
||||
|
||||
Learn from a batch of [`Example`](/api/example) objects. Delegates to
|
||||
[`predict`](/api/entitylinker#predict).
|
||||
[`predict`](/api/coref#predict).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
|
340
website/docs/api/span-predictor.md
Normal file
340
website/docs/api/span-predictor.md
Normal file
|
@ -0,0 +1,340 @@
|
|||
---
|
||||
title: SpanPredictor
|
||||
tag: class
|
||||
source: spacy/pipeline/span_predictor.py
|
||||
new: 3.4
|
||||
teaser: 'Pipeline component for resolving tokens into spans'
|
||||
api_base_class: /api/pipe
|
||||
api_string_name: span_predictor
|
||||
api_trainable: true
|
||||
---
|
||||
|
||||
A `SpanPredictor` component takes in tokens (represented as `Span`s of length
|
||||
|
||||
1. and resolves them into `Span`s of arbitrary length. The initial use case is
|
||||
as a post-processing step on word-level [coreference resolution](/api/coref).
|
||||
The input and output keys used to store `Span`s are configurable.
|
||||
|
||||
## Assigned Attributes {#assigned-attributes}
|
||||
|
||||
Predictions will be saved to `Doc.spans` as [`SpanGroup`s](/api/spangroup).
|
||||
|
||||
Input token spans will be read in using an input prefix, by default
|
||||
`"coref_head_clusters"`, and output spans will be saved using an output prefix
|
||||
(default `"coref_clusters"`) plus a serial number starting from zero. The
|
||||
prefixes are configurable.
|
||||
|
||||
| Location | Value |
|
||||
| ------------------------------------------------- | ------------------------------------------- |
|
||||
| `Doc.spans[output_prefix + "_" + cluster_number]` | One group of predicted spans. ~~SpanGroup~~ |
|
||||
|
||||
## Config and implementation {#config}
|
||||
|
||||
The default config is defined by the pipeline component factory and describes
|
||||
how the component should be configured. You can override its settings via the
|
||||
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
|
||||
[`config.cfg` for training](/usage/training#config). See the
|
||||
[model architectures](/api/architectures) documentation for details on the
|
||||
architectures and their arguments and hyperparameters.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.pipeline.span_predictor import DEFAULT_SPAN_PREDICTOR_MODEL
|
||||
> config={
|
||||
> "model": DEFAULT_SPAN_PREDICTOR_MODEL,
|
||||
> "span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
|
||||
> },
|
||||
> nlp.add_pipe("span_predictor", config=config)
|
||||
> ```
|
||||
|
||||
| Setting | Description |
|
||||
| --------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [SpanPredictor](/api/architectures#SpanPredictor). ~~Model~~ |
|
||||
| `input_prefix` | The prefix to use for input `SpanGroup`s. Defaults to `coref_head_clusters`. ~~str~~ |
|
||||
| `output_prefix` | The prefix for predicted `SpanGroup`s. Defaults to `coref_clusters`. ~~str~~ |
|
||||
|
||||
```python
|
||||
%%GITHUB_SPACY/spacy/pipeline/span_predictor.py
|
||||
```
|
||||
|
||||
## SpanPredictor.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> # Construction via add_pipe with default model
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
>
|
||||
> # Construction via add_pipe with custom model
|
||||
> config = {"model": {"@architectures": "my_span_predictor.v1"}}
|
||||
> span_predictor = nlp.add_pipe("span_predictor", config=config)
|
||||
>
|
||||
> # Construction from class
|
||||
> from spacy.pipeline import SpanPredictor
|
||||
> span_predictor = SpanPredictor(nlp.vocab, model)
|
||||
> ```
|
||||
|
||||
Create a new pipeline instance. In your application, you would normally use a
|
||||
shortcut for this and instantiate the component using its string name and
|
||||
[`nlp.add_pipe`](/api/language#add_pipe).
|
||||
|
||||
| Name | Description |
|
||||
| --------------- | --------------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model~~ |
|
||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `input_prefix` | The prefix to use for input `SpanGroup`s. Defaults to `coref_head_clusters`. ~~str~~ |
|
||||
| `output_prefix` | The prefix for predicted `SpanGroup`s. Defaults to `coref_clusters`. ~~str~~ |
|
||||
|
||||
## SpanPredictor.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
Apply the pipe to one document. The document is modified in place and returned.
|
||||
This usually happens under the hood when the `nlp` object is called on a text
|
||||
and all pipeline components are applied to the `Doc` in order. Both
|
||||
[`__call__`](#call) and [`pipe`](#pipe) delegate to the [`predict`](#predict)
|
||||
and [`set_annotations`](#set_annotations) methods.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("This is a sentence.")
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> # This usually happens under the hood
|
||||
> processed = span_predictor(doc)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------- |
|
||||
| `doc` | The document to process. ~~Doc~~ |
|
||||
| **RETURNS** | The processed document. ~~Doc~~ |
|
||||
|
||||
## SpanPredictor.pipe {#pipe tag="method"}
|
||||
|
||||
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||
when the `nlp` object is called on a text and all pipeline components are
|
||||
applied to the `Doc` in order. Both [`__call__`](/api/span-predictor#call) and
|
||||
[`pipe`](/api/span-predictor#pipe) delegate to the
|
||||
[`predict`](/api/span-predictor#predict) and
|
||||
[`set_annotations`](/api/span-predictor#set_annotations) methods.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> for doc in span_predictor.pipe(docs, batch_size=50):
|
||||
> pass
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------- |
|
||||
| `stream` | A stream of documents. ~~Iterable[Doc]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ |
|
||||
| **YIELDS** | The processed documents in order. ~~Doc~~ |
|
||||
|
||||
## SpanPredictor.initialize {#initialize tag="method"}
|
||||
|
||||
Initialize the component for training. `get_examples` should be a function that
|
||||
returns an iterable of [`Example`](/api/example) objects. The data examples are
|
||||
used to **initialize the model** of the component and can either be the full
|
||||
training data or a representative sample. Initialization includes validating the
|
||||
network,
|
||||
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and
|
||||
setting up the label scheme based on the data. This method is typically called
|
||||
by [`Language.initialize`](/api/language#initialize).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> span_predictor.initialize(lambda: [], nlp=nlp)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |
|
||||
|
||||
## SpanPredictor.predict {#predict tag="method"}
|
||||
|
||||
Apply the component's model to a batch of [`Doc`](/api/doc) objects, without
|
||||
modifying them. Predictions are returned as a list of `MentionClusters`, one for
|
||||
each input `Doc`. A `MentionClusters` instance is just a list of lists of pairs
|
||||
of `int`s, where each item corresponds to an input `SpanGroup`, and the `int`s
|
||||
correspond to token indices.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> spans = span_predictor.predict([doc1, doc2])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------------- |
|
||||
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
|
||||
| **RETURNS** | The predicted spans for the `Doc`s. ~~List[MentionClusters]~~ |
|
||||
|
||||
## SpanPredictor.set_annotations {#set_annotations tag="method"}
|
||||
|
||||
Modify a batch of documents, saving predictions using the output prefix in
|
||||
`Doc.spans`.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> spans = span_predictor.predict([doc1, doc2])
|
||||
> span_predictor.set_annotations([doc1, doc2], spans)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------- | ------------------------------------------------------------- |
|
||||
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
|
||||
| `spans` | The predicted spans for the `docs`. ~~List[MentionClusters]~~ |
|
||||
|
||||
## SpanPredictor.update {#update tag="method"}
|
||||
|
||||
Learn from a batch of [`Example`](/api/example) objects. Delegates to
|
||||
[`predict`](/api/span-predictor#predict).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> optimizer = nlp.initialize()
|
||||
> losses = span_predictor.update(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `drop` | The dropout rate. ~~float~~ |
|
||||
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||
|
||||
## SpanPredictor.create_optimizer {#create_optimizer tag="method"}
|
||||
|
||||
Create an optimizer for the pipeline component.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> optimizer = span_predictor.create_optimizer()
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ---------------------------- |
|
||||
| **RETURNS** | The optimizer. ~~Optimizer~~ |
|
||||
|
||||
## SpanPredictor.use_params {#use_params tag="method, contextmanager"}
|
||||
|
||||
Modify the pipe's model, to use the given parameter values. At the end of the
|
||||
context, the original parameters are restored.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> with span_predictor.use_params(optimizer.averages):
|
||||
> span_predictor.to_disk("/best_model")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------- | -------------------------------------------------- |
|
||||
| `params` | The parameter values to use in the model. ~~dict~~ |
|
||||
|
||||
## SpanPredictor.to_disk {#to_disk tag="method"}
|
||||
|
||||
Serialize the pipe to disk.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> span_predictor.to_disk("/path/to/span_predictor")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
|
||||
## SpanPredictor.from_disk {#from_disk tag="method"}
|
||||
|
||||
Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> span_predictor.from_disk("/path/to/span_predictor")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ----------------------------------------------------------------------------------------------- |
|
||||
| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The modified `SpanPredictor` object. ~~SpanPredictor~~ |
|
||||
|
||||
## SpanPredictor.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> span_predictor_bytes = span_predictor.to_bytes()
|
||||
> ```
|
||||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The serialized form of the `SpanPredictor` object. ~~bytes~~ |
|
||||
|
||||
## SpanPredictor.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
Load the pipe from a bytestring. Modifies the object in place and returns it.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> span_predictor_bytes = span_predictor.to_bytes()
|
||||
> span_predictor = nlp.add_pipe("span_predictor")
|
||||
> span_predictor.from_bytes(span_predictor_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||
| `bytes_data` | The data to load from. ~~bytes~~ |
|
||||
| _keyword-only_ | |
|
||||
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||
| **RETURNS** | The `SpanPredictor` object. ~~SpanPredictor~~ |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
||||
During serialization, spaCy will export several data fields used to restore
|
||||
different aspects of the object. If needed, you can exclude them from
|
||||
serialization by passing in the string names via the `exclude` argument.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> data = span_predictor.to_disk("/path", exclude=["vocab"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------- | -------------------------------------------------------------- |
|
||||
| `vocab` | The shared [`Vocab`](/api/vocab). |
|
||||
| `cfg` | The config file. You usually don't want to exclude this. |
|
||||
| `model` | The binary model data. You usually don't want to exclude this. |
|
Loading…
Reference in New Issue
Block a user