mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge pull request #6045 from svlandeg/feature/more-layers-docs [ci skip]
This commit is contained in:
commit
1955aaaa20
|
@ -165,7 +165,7 @@ def MultiHashEmbed(
|
||||||
|
|
||||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||||
def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
||||||
"""Construct an embedded representations based on character embeddings, using
|
"""Construct an embedded representation based on character embeddings, using
|
||||||
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
||||||
each word, taken from the beginning and end of the word equally. Padding is
|
each word, taken from the beginning and end of the word equally. Padding is
|
||||||
used in the centre for words that are too short.
|
used in the centre for words that are too short.
|
||||||
|
@ -176,8 +176,8 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
||||||
ensures that the final character is always in the last position, instead
|
ensures that the final character is always in the last position, instead
|
||||||
of being in an arbitrary position depending on the word length.
|
of being in an arbitrary position depending on the word length.
|
||||||
|
|
||||||
The characters are embedded in a embedding table with 256 rows, and the
|
The characters are embedded in a embedding table with a given number of rows,
|
||||||
vectors concatenated. A hash-embedded vector of the NORM of the word is
|
and the vectors concatenated. A hash-embedded vector of the NORM of the word is
|
||||||
also concatenated on, and the result is then passed through a feed-forward
|
also concatenated on, and the result is then passed through a feed-forward
|
||||||
network to construct a single vector to represent the information.
|
network to construct a single vector to represent the information.
|
||||||
|
|
||||||
|
|
|
@ -576,7 +576,7 @@ cdef class Doc:
|
||||||
entity_type = 0
|
entity_type = 0
|
||||||
kb_id = 0
|
kb_id = 0
|
||||||
|
|
||||||
# Set ent_iob to Missing (0) bij default unless this token was nered before
|
# Set ent_iob to Missing (0) by default unless this token was nered before
|
||||||
ent_iob = 0
|
ent_iob = 0
|
||||||
if self.c[i].ent_iob != 0:
|
if self.c[i].ent_iob != 0:
|
||||||
ent_iob = 2
|
ent_iob = 2
|
||||||
|
|
|
@ -181,10 +181,10 @@ characters would be `"jumpping"`: 4 from the start, 4 from the end. This ensures
|
||||||
that the final character is always in the last position, instead of being in an
|
that the final character is always in the last position, instead of being in an
|
||||||
arbitrary position depending on the word length.
|
arbitrary position depending on the word length.
|
||||||
|
|
||||||
The characters are embedded in a embedding table with 256 rows, and the vectors
|
The characters are embedded in a embedding table with a given number of rows,
|
||||||
concatenated. A hash-embedded vector of the `NORM` of the word is also
|
and the vectors concatenated. A hash-embedded vector of the `NORM` of the word
|
||||||
concatenated on, and the result is then passed through a feed-forward network to
|
is also concatenated on, and the result is then passed through a feed-forward
|
||||||
construct a single vector to represent the information.
|
network to construct a single vector to represent the information.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
|
|
@ -293,7 +293,11 @@ context, the original parameters are restored.
|
||||||
|
|
||||||
## DependencyParser.add_label {#add_label tag="method"}
|
## DependencyParser.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
Add a new label to the pipe.
|
Add a new label to the pipe. Note that you don't have to call this method if you
|
||||||
|
provide a **representative data sample** to the
|
||||||
|
[`begin_training`](#begin_training) method. In this case, all labels found in
|
||||||
|
the sample will be automatically added to the model, and the output dimension
|
||||||
|
will be [inferred](/usage/layers-architectures#shape-inference) automatically.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -307,6 +311,25 @@ Add a new label to the pipe.
|
||||||
| `label` | The label to add. ~~str~~ |
|
| `label` | The label to add. ~~str~~ |
|
||||||
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
||||||
|
|
||||||
|
## DependencyParser.set_output {#set_output tag="method"}
|
||||||
|
|
||||||
|
Change the output dimension of the component's model by calling the model's
|
||||||
|
attribute `resize_output`. This is a function that takes the original model and
|
||||||
|
the new output dimension `nO`, and changes the model in place. When resizing an
|
||||||
|
already trained model, care should be taken to avoid the "catastrophic
|
||||||
|
forgetting" problem.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> parser = nlp.add_pipe("parser")
|
||||||
|
> parser.set_output(512)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ---- | --------------------------------- |
|
||||||
|
| `nO` | The new output dimension. ~~int~~ |
|
||||||
|
|
||||||
## DependencyParser.to_disk {#to_disk tag="method"}
|
## DependencyParser.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
Serialize the pipe to disk.
|
Serialize the pipe to disk.
|
||||||
|
|
|
@ -281,7 +281,11 @@ context, the original parameters are restored.
|
||||||
|
|
||||||
## EntityRecognizer.add_label {#add_label tag="method"}
|
## EntityRecognizer.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
Add a new label to the pipe.
|
Add a new label to the pipe. Note that you don't have to call this method if you
|
||||||
|
provide a **representative data sample** to the
|
||||||
|
[`begin_training`](#begin_training) method. In this case, all labels found in
|
||||||
|
the sample will be automatically added to the model, and the output dimension
|
||||||
|
will be [inferred](/usage/layers-architectures#shape-inference) automatically.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -295,6 +299,25 @@ Add a new label to the pipe.
|
||||||
| `label` | The label to add. ~~str~~ |
|
| `label` | The label to add. ~~str~~ |
|
||||||
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
||||||
|
|
||||||
|
## EntityRecognizer.set_output {#set_output tag="method"}
|
||||||
|
|
||||||
|
Change the output dimension of the component's model by calling the model's
|
||||||
|
attribute `resize_output`. This is a function that takes the original model and
|
||||||
|
the new output dimension `nO`, and changes the model in place. When resizing an
|
||||||
|
already trained model, care should be taken to avoid the "catastrophic
|
||||||
|
forgetting" problem.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> ner = nlp.add_pipe("ner")
|
||||||
|
> ner.set_output(512)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ---- | --------------------------------- |
|
||||||
|
| `nO` | The new output dimension. ~~int~~ |
|
||||||
|
|
||||||
## EntityRecognizer.to_disk {#to_disk tag="method"}
|
## EntityRecognizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
Serialize the pipe to disk.
|
Serialize the pipe to disk.
|
||||||
|
|
|
@ -205,9 +205,16 @@ examples can either be the full training data or a representative sample. They
|
||||||
are used to **initialize the models** of trainable pipeline components and are
|
are used to **initialize the models** of trainable pipeline components and are
|
||||||
passed each component's [`begin_training`](/api/pipe#begin_training) method, if
|
passed each component's [`begin_training`](/api/pipe#begin_training) method, if
|
||||||
available. Initialization includes validating the network,
|
available. Initialization includes validating the network,
|
||||||
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and
|
[inferring missing shapes](/usage/layers-architectures#shape-inference) and
|
||||||
setting up the label scheme based on the data.
|
setting up the label scheme based on the data.
|
||||||
|
|
||||||
|
If no `get_examples` function is provided when calling `nlp.begin_training`, the
|
||||||
|
pipeline components will be initialized with generic data. In this case, it is
|
||||||
|
crucial that the output dimension of each component has already been defined
|
||||||
|
either in the [config](/usage/training#config), or by calling
|
||||||
|
[`pipe.add_label`](/api/pipe#add_label) for each possible output label (e.g. for
|
||||||
|
the tagger or textcat).
|
||||||
|
|
||||||
<Infobox variant="warning" title="Changed in v3.0">
|
<Infobox variant="warning" title="Changed in v3.0">
|
||||||
|
|
||||||
The `Language.update` method now takes a **function** that is called with no
|
The `Language.update` method now takes a **function** that is called with no
|
||||||
|
|
|
@ -258,6 +258,12 @@ context, the original parameters are restored.
|
||||||
|
|
||||||
Add a new label to the pipe. If the `Morphologizer` should set annotations for
|
Add a new label to the pipe. If the `Morphologizer` should set annotations for
|
||||||
both `pos` and `morph`, the label should include the UPOS as the feature `POS`.
|
both `pos` and `morph`, the label should include the UPOS as the feature `POS`.
|
||||||
|
Raises an error if the output dimension is already set, or if the model has
|
||||||
|
already been fully [initialized](#begin_training). Note that you don't have to
|
||||||
|
call this method if you provide a **representative data sample** to the
|
||||||
|
[`begin_training`](#begin_training) method. In this case, all labels found in
|
||||||
|
the sample will be automatically added to the model, and the output dimension
|
||||||
|
will be [inferred](/usage/layers-architectures#shape-inference) automatically.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
|
|
@ -286,9 +286,6 @@ context, the original parameters are restored.
|
||||||
|
|
||||||
## Pipe.add_label {#add_label tag="method"}
|
## Pipe.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
Add a new label to the pipe. It's possible to extend trained models with new
|
|
||||||
labels, but care should be taken to avoid the "catastrophic forgetting" problem.
|
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
|
@ -296,10 +293,81 @@ labels, but care should be taken to avoid the "catastrophic forgetting" problem.
|
||||||
> pipe.add_label("MY_LABEL")
|
> pipe.add_label("MY_LABEL")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
|
Add a new label to the pipe, to be predicted by the model. The actual
|
||||||
|
implementation depends on the specific component, but in general `add_label`
|
||||||
|
shouldn't be called if the output dimension is already set, or if the model has
|
||||||
|
already been fully [initialized](#begin_training). If these conditions are
|
||||||
|
violated, the function will raise an Error. The exception to this rule is when
|
||||||
|
the component is [resizable](#is_resizable), in which case
|
||||||
|
[`set_output`](#set_output) should be called to ensure that the model is
|
||||||
|
properly resized.
|
||||||
|
|
||||||
|
<Infobox variant="danger">
|
||||||
|
|
||||||
|
This method needs to be overwritten with your own custom `add_label` method.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | ----------------------------------------------------------- |
|
| ----------- | ------------------------------------------------------- |
|
||||||
| `label` | The label to add. ~~str~~ |
|
| `label` | The label to add. ~~str~~ |
|
||||||
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
| **RETURNS** | 0 if the label is already present, otherwise 1. ~~int~~ |
|
||||||
|
|
||||||
|
Note that in general, you don't have to call `pipe.add_label` if you provide a
|
||||||
|
representative data sample to the [`begin_training`](#begin_training) method. In
|
||||||
|
this case, all labels found in the sample will be automatically added to the
|
||||||
|
model, and the output dimension will be
|
||||||
|
[inferred](/usage/layers-architectures#shape-inference) automatically.
|
||||||
|
|
||||||
|
## Pipe.is_resizable {#is_resizable tag="method"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> can_resize = pipe.is_resizable()
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> ### Custom resizing
|
||||||
|
> def custom_resize(model, new_nO):
|
||||||
|
> # adjust model
|
||||||
|
> return model
|
||||||
|
>
|
||||||
|
> custom_model.attrs["resize_output"] = custom_resize
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Check whether or not the output dimension of the component's model can be
|
||||||
|
resized. If this method returns `True`, [`set_output`](#set_output) can be
|
||||||
|
called to change the model's output dimension.
|
||||||
|
|
||||||
|
For built-in components that are not resizable, you have to create and train a
|
||||||
|
new model from scratch with the appropriate architecture and output dimension.
|
||||||
|
For custom components, you can implement a `resize_output` function and add it
|
||||||
|
as an attribute to the component's model.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ---------------------------------------------------------------------------------------------- |
|
||||||
|
| **RETURNS** | Whether or not the output dimension of the model can be changed after initialization. ~~bool~~ |
|
||||||
|
|
||||||
|
## Pipe.set_output {#set_output tag="method"}
|
||||||
|
|
||||||
|
Change the output dimension of the component's model. If the component is not
|
||||||
|
[resizable](#is_resizable), this method will raise a `NotImplementedError`. If a
|
||||||
|
component is resizable, the model's attribute `resize_output` will be called.
|
||||||
|
This is a function that takes the original model and the new output dimension
|
||||||
|
`nO`, and changes the model in place. When resizing an already trained model,
|
||||||
|
care should be taken to avoid the "catastrophic forgetting" problem.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> if pipe.is_resizable():
|
||||||
|
> pipe.set_output(512)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ---- | --------------------------------- |
|
||||||
|
| `nO` | The new output dimension. ~~int~~ |
|
||||||
|
|
||||||
## Pipe.to_disk {#to_disk tag="method"}
|
## Pipe.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -250,7 +250,7 @@ Score a batch of examples.
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------- | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
||||||
| **RETURNS** | The scores, produced by [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"tag"`. ~~Dict[str, float]~~ |
|
| **RETURNS** | The scores, produced by [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"tag"`. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
|
@ -288,7 +288,13 @@ context, the original parameters are restored.
|
||||||
|
|
||||||
## Tagger.add_label {#add_label tag="method"}
|
## Tagger.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
Add a new label to the pipe.
|
Add a new label to the pipe. Raises an error if the output dimension is already
|
||||||
|
set, or if the model has already been fully [initialized](#begin_training). Note
|
||||||
|
that you don't have to call this method if you provide a **representative data
|
||||||
|
sample** to the [`begin_training`](#begin_training) method. In this case, all
|
||||||
|
labels found in the sample will be automatically added to the model, and the
|
||||||
|
output dimension will be [inferred](/usage/layers-architectures#shape-inference)
|
||||||
|
automatically.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
|
|
@ -297,7 +297,13 @@ Modify the pipe's model, to use the given parameter values.
|
||||||
|
|
||||||
## TextCategorizer.add_label {#add_label tag="method"}
|
## TextCategorizer.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
Add a new label to the pipe.
|
Add a new label to the pipe. Raises an error if the output dimension is already
|
||||||
|
set, or if the model has already been fully [initialized](#begin_training). Note
|
||||||
|
that you don't have to call this method if you provide a **representative data
|
||||||
|
sample** to the [`begin_training`](#begin_training) method. In this case, all
|
||||||
|
labels found in the sample will be automatically added to the model, and the
|
||||||
|
output dimension will be [inferred](/usage/layers-architectures#shape-inference)
|
||||||
|
automatically.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
|
|
@ -5,7 +5,7 @@ menu:
|
||||||
- ['Type Signatures', 'type-sigs']
|
- ['Type Signatures', 'type-sigs']
|
||||||
- ['Swapping Architectures', 'swap-architectures']
|
- ['Swapping Architectures', 'swap-architectures']
|
||||||
- ['PyTorch & TensorFlow', 'frameworks']
|
- ['PyTorch & TensorFlow', 'frameworks']
|
||||||
- ['Thinc Models', 'thinc']
|
- ['Custom Thinc Models', 'thinc']
|
||||||
- ['Trainable Components', 'components']
|
- ['Trainable Components', 'components']
|
||||||
next: /usage/projects
|
next: /usage/projects
|
||||||
---
|
---
|
||||||
|
@ -118,7 +118,7 @@ code.
|
||||||
|
|
||||||
If no model is specified for the [`TextCategorizer`](/api/textcategorizer), the
|
If no model is specified for the [`TextCategorizer`](/api/textcategorizer), the
|
||||||
[TextCatEnsemble](/api/architectures#TextCatEnsemble) architecture is used by
|
[TextCatEnsemble](/api/architectures#TextCatEnsemble) architecture is used by
|
||||||
default. This architecture combines a simpel bag-of-words model with a neural
|
default. This architecture combines a simple bag-of-words model with a neural
|
||||||
network, usually resulting in the most accurate results, but at the cost of
|
network, usually resulting in the most accurate results, but at the cost of
|
||||||
speed. The config file for this model would look something like this:
|
speed. The config file for this model would look something like this:
|
||||||
|
|
||||||
|
@ -225,28 +225,263 @@ you'll be able to try it out in any of the spaCy components.
|
||||||
|
|
||||||
Thinc allows you to [wrap models](https://thinc.ai/docs/usage-frameworks)
|
Thinc allows you to [wrap models](https://thinc.ai/docs/usage-frameworks)
|
||||||
written in other machine learning frameworks like PyTorch, TensorFlow and MXNet
|
written in other machine learning frameworks like PyTorch, TensorFlow and MXNet
|
||||||
using a unified [`Model`](https://thinc.ai/docs/api-model) API. As well as
|
using a unified [`Model`](https://thinc.ai/docs/api-model) API. This makes it
|
||||||
**wrapping whole models**, Thinc lets you call into an external framework for
|
easy to use a model implemented in a different framework to power a component in
|
||||||
just **part of your model**: you can have a model where you use PyTorch just for
|
your spaCy pipeline. For example, to wrap a PyTorch model as a Thinc `Model`,
|
||||||
the transformer layers, using "native" Thinc layers to do fiddly input and
|
you can use Thinc's
|
||||||
output transformations and add on task-specific "heads", as efficiency is less
|
[`PyTorchWrapper`](https://thinc.ai/docs/api-layers#pytorchwrapper):
|
||||||
of a consideration for those parts of the network.
|
|
||||||
|
|
||||||
<!-- TODO: custom tagger implemented in PyTorch, wrapped as Thinc model, link off to project (with notebook?) -->
|
```python
|
||||||
|
from thinc.api import PyTorchWrapper
|
||||||
|
|
||||||
## Implementing models in Thinc {#thinc}
|
wrapped_pt_model = PyTorchWrapper(torch_model)
|
||||||
|
```
|
||||||
|
|
||||||
<!-- TODO: use same example as above, custom tagger, but implemented in Thinc, link off to Thinc docs where appropriate -->
|
Let's use PyTorch to define a very simple neural network consisting of two
|
||||||
|
hidden `Linear` layers with `ReLU` activation and dropout, and a
|
||||||
|
softmax-activated output layer:
|
||||||
|
|
||||||
## Models for trainable components {#components}
|
```python
|
||||||
|
### PyTorch model
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
torch_model = nn.Sequential(
|
||||||
|
nn.Linear(width, hidden_width),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(dropout),
|
||||||
|
nn.Linear(hidden_width, nO),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(dropout),
|
||||||
|
nn.Softmax(dim=1)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The resulting wrapped `Model` can be used as a **custom architecture** as such,
|
||||||
|
or can be a **subcomponent of a larger model**. For instance, we can use Thinc's
|
||||||
|
[`chain`](https://thinc.ai/docs/api-layers#chain) combinator, which works like
|
||||||
|
`Sequential` in PyTorch, to combine the wrapped model with other components in a
|
||||||
|
larger network. This effectively means that you can easily wrap different
|
||||||
|
components from different frameworks, and "glue" them together with Thinc:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from thinc.api import chain, with_array
|
||||||
|
from spacy.ml import CharacterEmbed
|
||||||
|
|
||||||
|
char_embed = CharacterEmbed(width, embed_size, nM, nC)
|
||||||
|
model = chain(char_embed, with_array(wrapped_pt_model))
|
||||||
|
```
|
||||||
|
|
||||||
|
In the above example, we have combined our custom PyTorch model with a character
|
||||||
|
embedding layer defined by spaCy.
|
||||||
|
[CharacterEmbed](/api/architectures#CharacterEmbed) returns a `Model` that takes
|
||||||
|
a ~~List[Doc]~~ as input, and outputs a ~~List[Floats2d]~~. To make sure that
|
||||||
|
the wrapped PyTorch model receives valid inputs, we use Thinc's
|
||||||
|
[`with_array`](https://thinc.ai/docs/api-layers#with_array) helper.
|
||||||
|
|
||||||
|
You could also implement a model that only uses PyTorch for the transformer
|
||||||
|
layers, and "native" Thinc layers to do fiddly input and output transformations
|
||||||
|
and add on task-specific "heads", as efficiency is less of a consideration for
|
||||||
|
those parts of the network.
|
||||||
|
|
||||||
|
### Using wrapped models {#frameworks-usage}
|
||||||
|
|
||||||
|
To use our custom model including the PyTorch subnetwork, all we need to do is
|
||||||
|
register the architecture using the
|
||||||
|
[`architectures` registry](/api/top-level#registry). This will assign the
|
||||||
|
architecture a name so spaCy knows how to find it, and allows passing in
|
||||||
|
arguments like hyperparameters via the [config](/usage/training#config). The
|
||||||
|
full example then becomes:
|
||||||
|
|
||||||
|
```python
|
||||||
|
### Registering the architecture {highlight="9"}
|
||||||
|
from typing import List
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
from thinc.api import Model, PyTorchWrapper, chain, with_array
|
||||||
|
import spacy
|
||||||
|
from spacy.tokens.doc import Doc
|
||||||
|
from spacy.ml import CharacterEmbed
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
@spacy.registry.architectures("CustomTorchModel.v1")
|
||||||
|
def create_torch_model(
|
||||||
|
nO: int,
|
||||||
|
width: int,
|
||||||
|
hidden_width: int,
|
||||||
|
embed_size: int,
|
||||||
|
nM: int,
|
||||||
|
nC: int,
|
||||||
|
dropout: float,
|
||||||
|
) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
char_embed = CharacterEmbed(width, embed_size, nM, nC)
|
||||||
|
torch_model = nn.Sequential(
|
||||||
|
nn.Linear(width, hidden_width),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(dropout),
|
||||||
|
nn.Linear(hidden_width, nO),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(dropout),
|
||||||
|
nn.Softmax(dim=1)
|
||||||
|
)
|
||||||
|
wrapped_pt_model = PyTorchWrapper(torch_model)
|
||||||
|
model = chain(char_embed, with_array(wrapped_pt_model))
|
||||||
|
return model
|
||||||
|
```
|
||||||
|
|
||||||
|
The model definition can now be used in any existing trainable spaCy component,
|
||||||
|
by specifying it in the config file. In this configuration, all required
|
||||||
|
parameters for the various subcomponents of the custom architecture are passed
|
||||||
|
in as settings via the config.
|
||||||
|
|
||||||
|
```ini
|
||||||
|
### config.cfg (excerpt) {highlight="5-5"}
|
||||||
|
[components.tagger]
|
||||||
|
factory = "tagger"
|
||||||
|
|
||||||
|
[components.tagger.model]
|
||||||
|
@architectures = "CustomTorchModel.v1"
|
||||||
|
nO = 50
|
||||||
|
width = 96
|
||||||
|
hidden_width = 48
|
||||||
|
embed_size = 2000
|
||||||
|
nM = 64
|
||||||
|
nC = 8
|
||||||
|
dropout = 0.2
|
||||||
|
```
|
||||||
|
|
||||||
|
<Infobox variant="warning">
|
||||||
|
|
||||||
|
Remember that it is best not to rely on any (hidden) default values, to ensure
|
||||||
|
that training configs are complete and experiments fully reproducible.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
## Custom models with Thinc {#thinc}
|
||||||
|
|
||||||
|
Of course it's also possible to define the `Model` from the previous section
|
||||||
|
entirely in Thinc. The Thinc documentation provides details on the
|
||||||
|
[various layers](https://thinc.ai/docs/api-layers) and helper functions
|
||||||
|
available. Combinators can also be used to
|
||||||
|
[overload operators](https://thinc.ai/docs/usage-models#operators) and a common
|
||||||
|
usage pattern is to bind `chain` to `>>`. The "native" Thinc version of our
|
||||||
|
simple neural network would then become:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from thinc.api import chain, with_array, Model, Relu, Dropout, Softmax
|
||||||
|
from spacy.ml import CharacterEmbed
|
||||||
|
|
||||||
|
char_embed = CharacterEmbed(width, embed_size, nM, nC)
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
layers = (
|
||||||
|
Relu(hidden_width, width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Relu(hidden_width, hidden_width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Softmax(nO, hidden_width)
|
||||||
|
)
|
||||||
|
model = char_embed >> with_array(layers)
|
||||||
|
```
|
||||||
|
|
||||||
|
<Infobox variant="warning" title="Important note on inputs and outputs">
|
||||||
|
|
||||||
|
Note that Thinc layers define the output dimension (`nO`) as the first argument,
|
||||||
|
followed (optionally) by the input dimension (`nI`). This is in contrast to how
|
||||||
|
the PyTorch layers are defined, where `in_features` precedes `out_features`.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
### Shape inference in Thinc {#thinc-shape-inference}
|
||||||
|
|
||||||
|
It is **not** strictly necessary to define all the input and output dimensions
|
||||||
|
for each layer, as Thinc can perform
|
||||||
|
[shape inference](https://thinc.ai/docs/usage-models#validation) between
|
||||||
|
sequential layers by matching up the output dimensionality of one layer to the
|
||||||
|
input dimensionality of the next. This means that we can simplify the `layers`
|
||||||
|
definition:
|
||||||
|
|
||||||
|
> #### Diff
|
||||||
|
>
|
||||||
|
> ```diff
|
||||||
|
> layers = (
|
||||||
|
> Relu(hidden_width, width)
|
||||||
|
> >> Dropout(dropout)
|
||||||
|
> - >> Relu(hidden_width, hidden_width)
|
||||||
|
> + >> Relu(hidden_width)
|
||||||
|
> >> Dropout(dropout)
|
||||||
|
> - >> Softmax(nO, hidden_width)
|
||||||
|
> + >> Softmax(nO)
|
||||||
|
> )
|
||||||
|
> ```
|
||||||
|
|
||||||
|
```python
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
layers = (
|
||||||
|
Relu(hidden_width, width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Relu(hidden_width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Softmax(nO)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Thinc can even go one step further and **deduce the correct input dimension** of
|
||||||
|
the first layer, and output dimension of the last. To enable this functionality,
|
||||||
|
you have to call
|
||||||
|
[`Model.initialize`](https://thinc.ai/docs/api-model#initialize) with an **input
|
||||||
|
sample** `X` and an **output sample** `Y` with the correct dimensions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
### Shape inference with initialization {highlight="3,7,10"}
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
layers = (
|
||||||
|
Relu(hidden_width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Relu(hidden_width)
|
||||||
|
>> Dropout(dropout)
|
||||||
|
>> Softmax()
|
||||||
|
)
|
||||||
|
model = char_embed >> with_array(layers)
|
||||||
|
model.initialize(X=input_sample, Y=output_sample)
|
||||||
|
```
|
||||||
|
|
||||||
|
The built-in [pipeline components](/usage/processing-pipelines) in spaCy ensure
|
||||||
|
that their internal models are **always initialized** with appropriate sample
|
||||||
|
data. In this case, `X` is typically a ~~List[Doc]~~, while `Y` is typically a
|
||||||
|
~~List[Array1d]~~ or ~~List[Array2d]~~, depending on the specific task. This
|
||||||
|
functionality is triggered when
|
||||||
|
[`nlp.begin_training`](/api/language#begin_training) is called.
|
||||||
|
|
||||||
|
### Dropout and normalization in Thinc {#thinc-dropout-norm}
|
||||||
|
|
||||||
|
Many of the available Thinc [layers](https://thinc.ai/docs/api-layers) allow you
|
||||||
|
to define a `dropout` argument that will result in "chaining" an additional
|
||||||
|
[`Dropout`](https://thinc.ai/docs/api-layers#dropout) layer. Optionally, you can
|
||||||
|
often specify whether or not you want to add layer normalization, which would
|
||||||
|
result in an additional
|
||||||
|
[`LayerNorm`](https://thinc.ai/docs/api-layers#layernorm) layer. That means that
|
||||||
|
the following `layers` definition is equivalent to the previous:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
layers = (
|
||||||
|
Relu(hidden_width, dropout=dropout, normalize=False)
|
||||||
|
>> Relu(hidden_width, dropout=dropout, normalize=False)
|
||||||
|
>> Softmax()
|
||||||
|
)
|
||||||
|
model = char_embed >> with_array(layers)
|
||||||
|
model.initialize(X=input_sample, Y=output_sample)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Create new trainable components {#components}
|
||||||
|
|
||||||
<!-- TODO:
|
<!-- TODO:
|
||||||
|
|
||||||
- Interaction with `predict`, `get_loss` and `set_annotations`
|
- Interaction with `predict`, `get_loss` and `set_annotations`
|
||||||
- Initialization life-cycle with `begin_training`.
|
- Initialization life-cycle with `begin_training`, correlation with add_label
|
||||||
|
|
||||||
Example: relation extraction component (implemented as project template)
|
Example: relation extraction component (implemented as project template)
|
||||||
|
|
||||||
|
Avoid duplication with usage/processing-pipelines#trainable-components ?
|
||||||
|
|
||||||
-->
|
-->
|
||||||
|
|
||||||
![Diagram of a pipeline component with its model](../images/layers-architectures.svg)
|
![Diagram of a pipeline component with its model](../images/layers-architectures.svg)
|
||||||
|
|
|
@ -1028,11 +1028,11 @@ plug fully custom machine learning components into your pipeline. You'll need
|
||||||
the following:
|
the following:
|
||||||
|
|
||||||
1. **Model:** A Thinc [`Model`](https://thinc.ai/docs/api-model) instance. This
|
1. **Model:** A Thinc [`Model`](https://thinc.ai/docs/api-model) instance. This
|
||||||
can be a model using [layers](https://thinc.ai/docs/api-layers) implemented
|
can be a model using implemented in
|
||||||
in Thinc, or a [wrapped model](https://thinc.ai/docs/usage-frameworks)
|
[Thinc](/usage/layers-architectures#thinc), or a
|
||||||
implemented in PyTorch, TensorFlow, MXNet or a fully custom solution. The
|
[wrapped model](/usage/layers-architectures#frameworks) implemented in
|
||||||
model must take a list of [`Doc`](/api/doc) objects as input and can have any
|
PyTorch, TensorFlow, MXNet or a fully custom solution. The model must take a
|
||||||
type of output.
|
list of [`Doc`](/api/doc) objects as input and can have any type of output.
|
||||||
2. **Pipe subclass:** A subclass of [`Pipe`](/api/pipe) that implements at least
|
2. **Pipe subclass:** A subclass of [`Pipe`](/api/pipe) that implements at least
|
||||||
two methods: [`Pipe.predict`](/api/pipe#predict) and
|
two methods: [`Pipe.predict`](/api/pipe#predict) and
|
||||||
[`Pipe.set_annotations`](/api/pipe#set_annotations).
|
[`Pipe.set_annotations`](/api/pipe#set_annotations).
|
||||||
|
@ -1078,8 +1078,9 @@ _first_ create a `Model` from a [registered architecture](/api/architectures),
|
||||||
validate its arguments and _then_ pass the object forward to the component. This
|
validate its arguments and _then_ pass the object forward to the component. This
|
||||||
means that the config can express very complex, nested trees of objects – but
|
means that the config can express very complex, nested trees of objects – but
|
||||||
the objects don't have to pass the model settings all the way down to the
|
the objects don't have to pass the model settings all the way down to the
|
||||||
components. It also makes the components more **modular** and lets you swap
|
components. It also makes the components more **modular** and lets you
|
||||||
different architectures in your config, and re-use model definitions.
|
[swap](/usage/layers-architectures#swap-architectures) different architectures
|
||||||
|
in your config, and re-use model definitions.
|
||||||
|
|
||||||
```ini
|
```ini
|
||||||
### config.cfg (excerpt)
|
### config.cfg (excerpt)
|
||||||
|
@ -1134,7 +1135,7 @@ loss is calculated and to add evaluation scores to the training output.
|
||||||
For more details on how to implement your own trainable components and model
|
For more details on how to implement your own trainable components and model
|
||||||
architectures, and plug existing models implemented in PyTorch or TensorFlow
|
architectures, and plug existing models implemented in PyTorch or TensorFlow
|
||||||
into your spaCy pipeline, see the usage guide on
|
into your spaCy pipeline, see the usage guide on
|
||||||
[layers and model architectures](/usage/layers-architectures#components).
|
[layers and model architectures](/usage/layers-architectures).
|
||||||
|
|
||||||
</Infobox>
|
</Infobox>
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,8 @@
|
||||||
"Floats2d": "https://thinc.ai/docs/api-types#types",
|
"Floats2d": "https://thinc.ai/docs/api-types#types",
|
||||||
"Floats3d": "https://thinc.ai/docs/api-types#types",
|
"Floats3d": "https://thinc.ai/docs/api-types#types",
|
||||||
"FloatsXd": "https://thinc.ai/docs/api-types#types",
|
"FloatsXd": "https://thinc.ai/docs/api-types#types",
|
||||||
|
"Array1d": "https://thinc.ai/docs/api-types#types",
|
||||||
|
"Array2d": "https://thinc.ai/docs/api-types#types",
|
||||||
"Ops": "https://thinc.ai/docs/api-backends#ops",
|
"Ops": "https://thinc.ai/docs/api-backends#ops",
|
||||||
"cymem.Pool": "https://github.com/explosion/cymem",
|
"cymem.Pool": "https://github.com/explosion/cymem",
|
||||||
"preshed.BloomFilter": "https://github.com/explosion/preshed",
|
"preshed.BloomFilter": "https://github.com/explosion/preshed",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user