initialize and update explanation

This commit is contained in:
svlandeg 2020-10-05 00:39:36 +02:00
parent b0463fbf75
commit 52b660e9dc
2 changed files with 119 additions and 36 deletions

View File

@ -226,6 +226,12 @@ the "catastrophic forgetting" problem. This feature is experimental.
Find the loss and gradient of loss for the batch of documents and their Find the loss and gradient of loss for the batch of documents and their
predicted scores. predicted scores.
<Infobox variant="danger">
This method needs to be overwritten with your own custom `get_loss` method.
</Infobox>
> #### Example > #### Example
> >
> ```python > ```python

View File

@ -618,31 +618,97 @@ create a subclass of [`Pipe`](/api/pipe) that will hold the model:
```python ```python
from spacy.pipeline import Pipe from spacy.pipeline import Pipe
from spacy.language import Language
class RelationExtractor(Pipe): class RelationExtractor(Pipe):
def __init__(self, vocab, model, name="rel", labels=[]): def __init__(self, vocab, model, name="rel", labels=[]):
self.model = model
... ...
def predict(self, docs): def predict(self, docs):
... ...
def set_annotations(self, docs, scores): def set_annotations(self, docs, predictions):
... ...
@Language.factory("relation_extractor")
def make_relation_extractor(nlp, name, model, labels):
return RelationExtractor(nlp.vocab, model, name, labels=labels)
``` ```
Before the model can be used however, it needs to be
[initialized](/api/pipe#initialize). This function recieves either the full
training data set, or a representative sample. The training data can be used
to deduce all relevant labels. Alternatively, a list of labels can be provided,
or a script can call `rel_component.add_label()` to add each label separately.
The number of labels will define the output dimensionality of the network,
and will be used to do
[shape inference](https://thinc.ai/docs/usage-models#validation) throughout
the layers of the neural network. This is triggerd by calling `model.initialize`.
```python
from itertools import islice
def initialize(
self,
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Language = None,
labels: Optional[List[str]] = None,
):
if labels is not None:
for label in labels:
self.add_label(label)
else:
for example in get_examples():
relations = example.reference._.rel
for indices, label_dict in relations.items():
for label in label_dict.keys():
self.add_label(label)
subbatch = list(islice(get_examples(), 10))
doc_sample = [eg.reference for eg in subbatch]
label_sample = self._examples_to_truth(subbatch)
self.model.initialize(X=doc_sample, Y=label_sample)
```
The `initialize` method will be triggered whenever this component is part of an
`nlp` pipeline, and `nlp.initialize()` is invoked. After doing so, the pipeline
component and its internal model can be trained and used to make predictions.
During training the function [`update`](/api/pipe#update) is invoked which delegates to
[`self.model.begin_update`](https://thinc.ai/docs/api-model#begin_update) and
needs a function [`get_loss`](/api/pipe#get_loss) that will calculate the
loss for a batch of examples, as well as the gradient of loss that will be used to update
the weights of the model layers.
```python
def update(
self,
examples: Iterable[Example],
*,
drop: float = 0.0,
set_annotations: bool = False,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
...
docs = [ex.predicted for ex in examples]
predictions, backprop = self.model.begin_update(docs)
loss, gradient = self.get_loss(examples, predictions)
backprop(gradient)
losses[self.name] += loss
...
return losses
```
Thinc provides some [loss functions](https://thinc.ai/docs/api-loss) that can be used
for the implementation of the `get_loss` function.
When the internal model is trained, the component can be used to make novel predictions.
The [`predict`](/api/pipe#predict) function needs to be implemented for each The [`predict`](/api/pipe#predict) function needs to be implemented for each
subclass. In our case, we can simply delegate to the internal model's subclass of `Pipe`. In our case, we can simply delegate to the internal model's
[predict](https://thinc.ai/docs/api-model#predict) function: [predict](https://thinc.ai/docs/api-model#predict) function:
```python ```python
def predict(self, docs: Iterable[Doc]) -> Floats2d: def predict(self, docs: Iterable[Doc]) -> Floats2d:
scores = self.model.predict(docs) predictions = self.model.predict(docs)
return self.model.ops.asarray(scores) return self.model.ops.asarray(predictions)
``` ```
The other method that needs to be implemented, is The other method that needs to be implemented, is
@ -650,7 +716,7 @@ The other method that needs to be implemented, is
and modifies the given `Doc` object in place to hold the predictions. For our and modifies the given `Doc` object in place to hold the predictions. For our
relation extraction component, we'll store the data as a dictionary in a custom relation extraction component, we'll store the data as a dictionary in a custom
extension attribute `doc._.rel`. As keys, we represent the candidate pair by the extension attribute `doc._.rel`. As keys, we represent the candidate pair by the
start offsets of each entity, as this defines an entity uniquely within one start offsets of each entity, as this defines an entity pair uniquely within one
document. document.
To interpret the scores predicted by the REL model correctly, we need to To interpret the scores predicted by the REL model correctly, we need to
@ -674,7 +740,7 @@ related to those exact entities:
> ``` > ```
```python ```python
def set_annotations(self, docs: Iterable[Doc], rel_scores: Floats2d): def set_annotations(self, docs: Iterable[Doc], predictions: Floats2d):
c = 0 c = 0
get_candidates = self.model.attrs["get_candidates"] get_candidates = self.model.attrs["get_candidates"]
for doc in docs: for doc in docs:
@ -683,34 +749,45 @@ def set_annotations(self, docs: Iterable[Doc], rel_scores: Floats2d):
if offset not in doc._.rel: if offset not in doc._.rel:
doc._.rel[offset] = {} doc._.rel[offset] = {}
for j, label in enumerate(self.labels): for j, label in enumerate(self.labels):
doc._.rel[offset][label] = rel_scores[c, j] doc._.rel[offset][label] = predictions[c, j]
c += 1 c += 1
``` ```
Under the hood, when the pipe is applied to a document, it will delegate to these
two methods:
<Infobox title="This section is still under construction" emoji="🚧" variant="warning">
</Infobox>
<!-- TODO: write trainable component section
- Interaction with `predict`, `get_loss` and `set_annotations`
- Initialization life-cycle with `initialize`, correlation with add_label
Example: relation extraction component (implemented as project template)
Avoid duplication with usage/processing-pipelines#trainable-components ?
-->
<!-- ![Diagram of a pipeline component with its model](../images/layers-architectures.svg)
```python ```python
def update(self, examples): def __call__(self, Doc doc):
docs = [ex.predicted for ex in examples] predictions = self.predict([doc])
refs = [ex.reference for ex in examples] self.set_annotations([doc], predictions)
predictions, backprop = self.model.begin_update(docs) return doc
gradient = self.get_loss(predictions, refs)
backprop(gradient)
def __call__(self, doc):
predictions = self.model([doc])
self.set_annotations(predictions)
``` ```
-->
Once our `Pipe` subclass is fully implemented, we can
[register](http://localhost:8000/usage/processing-pipelines#custom-components-factories)
the component with the
`Language.factory` decorator. This will enable the creation of the component with
`nlp.add_pipe`, or via the config.
> ```
>
> [components.relation_extractor]
> factory = "relation_extractor"
> labels = []
>
> [components.relation_extractor.model]
> @architectures = "rel_model.v1"
> ...
> ```
```python
from spacy.language import Language
@Language.factory("relation_extractor")
def make_relation_extractor(nlp, name, model, labels):
return RelationExtractor(nlp.vocab, model, name, labels=labels)
```
<!-- TODO: refer once more to example project -->
<!-- ![Diagram of a pipeline component with its model](../images/layers-architectures.svg) -->