how to register and use custom function

svlandeg 2020-09-08 20:22:20 +02:00
parent b35a26ea5d
commit 1c476b4b41

@ -281,14 +281,70 @@ the transformer layers, and use "native" Thinc layers to do fiddly input and
output transformations and add on task-specific "heads", as efficiency is less
of a consideration for those parts of the network.
<!-- TODO: custom tagger implemented in PyTorch, wrapped as Thinc model, link off to project (with notebook?) -->
## Implementing models in Thinc {#thinc}
<!-- TODO: use same example as above, custom tagger, but implemented in Thinc, link off to Thinc docs where appropriate -->
## Models for trainable components {#components}
To use our custom model, including the PyTorch subnetwork, all we need to do is register
the architecture. The full example then becomes:
```python
from typing import List
from thinc.types import Floats2d
from thinc.api import Model, PyTorchWrapper, chain, with_array
import spacy
from spacy.tokens.doc import Doc
from spacy.ml import CharacterEmbed
from torch import nn


@spacy.registry.architectures("CustomTorchModel.v1")
def TorchModel(
    nO: int,
    width: int,
    hidden_width: int,
    embed_size: int,
    nM: int,
    nC: int,
    dropout: float,
) -> Model[List[Doc], List[Floats2d]]:
    # spaCy layer that turns each Doc into one (n_tokens, width) array
    embed = CharacterEmbed(width, embed_size, nM, nC)
    # Plain PyTorch subnetwork mapping width -> hidden_width -> nO
    torch_model = nn.Sequential(
        nn.Linear(width, hidden_width),
        nn.ReLU(),
        nn.Dropout2d(dropout),
        nn.Linear(hidden_width, nO),
        nn.ReLU(),
        nn.Dropout2d(dropout),
        nn.Softmax(dim=1)
    )
    # Wrap the PyTorch module so it can be used as a Thinc model
    wrapped_pt_model = PyTorchWrapper(torch_model)
    # with_array passes the list of per-Doc arrays through as one 2d array
    model = chain(embed, with_array(wrapped_pt_model))
    return model
```
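Conceptually, registering a function stores it under a string name, so that other code (such as a config loader) can find it by that name later. The sketch below illustrates the idea with a toy dictionary-based registry. `ARCHITECTURES`, `register` and `make_model` are hypothetical names made up for this illustration, and the code is not how spaCy implements its registry.

```python
from typing import Callable, Dict

# Toy registry: maps a string name to the function that builds a model.
# Illustrative stand-in only, not spaCy's actual registry.
ARCHITECTURES: Dict[str, Callable] = {}


def register(name: str) -> Callable[[Callable], Callable]:
    """Return a decorator that stores the decorated function under `name`."""
    def decorator(func: Callable) -> Callable:
        ARCHITECTURES[name] = func
        return func  # the function itself is left unchanged
    return decorator


@register("ToyModel.v1")
def make_model(width: int, hidden_width: int) -> dict:
    # Stand-in for building a real model: just record the settings.
    return {"width": width, "hidden_width": hidden_width}


# Later, code that only knows the name can look the function up and call it.
create = ARCHITECTURES["ToyModel.v1"]
model = create(width=96, hidden_width=48)
```

Because the lookup key is a plain string, the name can live in a text file while the Python code that builds the model lives elsewhere.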
Now you can use this model definition in any existing trainable spaCy component,
by specifying it in the config file:
```ini
### config.cfg (excerpt) {highlight="6-6"}
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "CustomTorchModel.v1"
nO = 50
nM = 64
nC = 8
dropout = 0.2
width = 96
hidden_width = 48
embed_size = 2000
```
In this configuration, we pass all required parameters for the various
subcomponents of the custom architecture as settings in the training config file.
Remember that it is best not to rely on any (hidden) default values, so that
training configs are complete and experiments are fully reproducible.
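One way to spot a setting that would silently fall back to a default is to compare the keys of the config block against the parameters of the registered function. The following is a minimal sketch using only the standard library. `missing_settings` and `toy_architecture` are hypothetical helpers for illustration and are not part of spaCy; `toy_architecture` only mirrors the argument names of the architecture above and gives `dropout` a default on purpose.

```python
import inspect
from typing import Any, Callable, Dict, List


def missing_settings(func: Callable, settings: Dict[str, Any]) -> List[str]:
    """Return the parameters of `func` that `settings` does not provide."""
    params = inspect.signature(func).parameters
    return [name for name in params if name not in settings]


# Hypothetical stand-in with the same argument names as the architecture above.
# The default for `dropout` is an assumption added to show the pitfall.
def toy_architecture(nO, width, hidden_width, embed_size, nM, nC, dropout=0.1):
    return None


# The [components.tagger.model] block, minus the "@architectures" key.
settings = {
    "nO": 50, "nM": 64, "nC": 8, "dropout": 0.2,
    "width": 96, "hidden_width": 48, "embed_size": 2000,
}

# Every parameter is covered, so nothing is reported as missing.
complete = missing_settings(toy_architecture, settings)

# Dropping "dropout" would still run, but only because of the hidden default.
partial = {k: v for k, v in settings.items() if k != "dropout"}
incomplete = missing_settings(toy_architecture, partial)
```

A check like this reports parameters by name regardless of whether they have defaults, which is the point: a config that omits `dropout` would still train, but the value used would no longer be visible in the config file.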
<!-- TODO:
- Interaction with `predict`, `get_loss` and `set_annotations`