mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-24 12:41:23 +03:00
821 lines
36 KiB
Plaintext
---
title: Embeddings, Transformers and Transfer Learning
teaser: Using transformer embeddings like BERT in spaCy
menu:
  - ['Embedding Layers', 'embedding-layers']
  - ['Transformers', 'transformers']
  - ['Static Vectors', 'static-vectors']
  - ['Pretraining', 'pretraining']
next: /usage/training
---

spaCy supports a number of **transfer and multi-task learning** workflows that
can often help improve your pipeline's efficiency or accuracy. Transfer learning
refers to techniques such as word vector tables and language model pretraining.
These techniques can be used to import knowledge from raw text into your
pipeline, so that your models are able to generalize better from your annotated
examples.

You can convert **word vectors** from popular tools like
[FastText](https://fasttext.cc) and [Gensim](https://radimrehurek.com/gensim),
or you can load in any pretrained **transformer model** if you install
[`spacy-transformers`](https://github.com/explosion/spacy-transformers). You can
also do your own language model pretraining via the
[`spacy pretrain`](/api/cli#pretrain) command. You can even **share** your
transformer or other contextual embedding model across multiple components,
which can make long pipelines several times more efficient. To use transfer
learning, you'll need at least a few annotated examples for what you're trying
to predict. Otherwise, you could try a "one-shot learning" approach based on
[vectors and similarity](/usage/linguistic-features#vectors-similarity).

<Accordion title="What’s the difference between word vectors and language models?" id="vectors-vs-language-models">

[Transformers](#transformers) are large and powerful neural networks that give
you better accuracy, but are harder to deploy in production, as they require a
GPU to run effectively. [Word vectors](#static-vectors) are a slightly older
technique that can give your models a smaller improvement in accuracy, and can
also provide some additional capabilities.

The key difference between word vectors and contextual language models such as
transformers is that word vectors model **lexical types**, rather than _tokens_.
If you have a list of terms with no context around them, a transformer model
like BERT can't really help you. BERT is designed to understand language **in
context**, which isn't what you have. A word vectors table will be a much better
fit for your task. However, if you do have words in context – whole sentences or
paragraphs of running text – word vectors will only provide a very rough
approximation of what the text is about.

Word vectors are also very computationally efficient, as they map a word to a
vector with a single indexing operation. Word vectors are therefore useful as a
way to **improve the accuracy** of neural network models, especially models that
are small or have received little or no pretraining. In spaCy, word vector
tables are only used as **static features**. spaCy does not backpropagate
gradients to the pretrained word vectors table. The static vectors table is
usually used in combination with a smaller table of learned task-specific
embeddings.
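
A minimal plain-Python sketch can make the type-versus-token distinction concrete. It is not spaCy's implementation: the words, the 3-dimensional vectors and the `lookup` helper are made up for illustration, and returning zeros for unknown words is a choice of this sketch. A static table maps each word to one row, so a lookup is a single indexing operation, and "bank" gets the same vector no matter which sentence it appears in.

```python
# Toy static vectors table (hypothetical data, illustration only).
# Each word *type* maps to exactly one row, independent of context.
key2row = {"bank": 0, "river": 1, "money": 2}
vectors = [
    [0.2, 0.7, 0.1],  # "bank"
    [0.9, 0.1, 0.0],  # "river"
    [0.0, 0.3, 0.8],  # "money"
]
ZERO = [0.0, 0.0, 0.0]  # this sketch's fallback for out-of-vocabulary words


def lookup(word):
    """Return the static vector for a word: one dict lookup plus one index."""
    row = key2row.get(word)
    return vectors[row] if row is not None else ZERO


sent1 = "deposit money at the bank".split()
sent2 = "sit on the river bank".split()
# The surrounding words play no role in the lookup.
bank1 = lookup(sent1[sent1.index("bank")])
bank2 = lookup(sent2[sent2.index("bank")])
print(bank1 == bank2)  # True
```

A contextual model such as a transformer would instead compute a different vector for each occurrence of "bank", based on its neighbours.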

</Accordion>

<Accordion title="When should I add word vectors to my model?">

Word vectors are not compatible with most [transformer models](#transformers),
but if you're training another type of NLP network, it's almost always worth
adding word vectors to your model. As well as improving your final accuracy,
word vectors often make experiments more consistent, as the accuracy you reach
will be less sensitive to how the network is randomly initialized. High variance
due to random chance can slow down your progress significantly, as you need to
run many experiments to filter the signal from the noise.

Word vector features need to be enabled prior to training, and the same word
vectors table will need to be available at runtime as well. You cannot add word
vector features once the model has already been trained, and you usually cannot
replace one word vectors table with another without causing a significant loss
of performance.

</Accordion>

## Shared embedding layers {id="embedding-layers"}

spaCy lets you share a single transformer or other token-to-vector ("tok2vec")
embedding layer between multiple components. You can even update the shared
layer, performing **multi-task learning**. Reusing the tok2vec layer between
components can make your pipeline run a lot faster and result in much smaller
models. However, it can make the pipeline less modular and make it more
difficult to swap components or retrain parts of the pipeline. Multi-task
learning can affect your accuracy (either positively or negatively), and may
require some retuning of your hyperparameters.

| Shared                                                                                      | Independent                                                             |
| ------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| ✅ **smaller:** models only need to include a single copy of the embeddings                 | ❌ **larger:** models need to include the embeddings for each component |
| ✅ **faster:** embed the documents once for your whole pipeline                             | ❌ **slower:** rerun the embedding for each component                   |
| ❌ **less composable:** all components require the same embedding component in the pipeline | ✅ **modular:** components can be moved and swapped freely              |

You can share a single transformer or other tok2vec model between multiple
components by adding a [`Transformer`](/api/transformer) or
[`Tok2Vec`](/api/tok2vec) component near the start of your pipeline. Components
later in the pipeline can "connect" to it by including a **listener layer** like
[Tok2VecListener](/api/architectures#Tok2VecListener) within their model.

At the beginning of training, the [`Tok2Vec`](/api/tok2vec) component will grab
a reference to the relevant listener layers in the rest of your pipeline. When
it processes a batch of documents, it will pass forward its predictions to the
listeners, allowing the listeners to **reuse the predictions** when they are
eventually called. A similar mechanism is used to pass gradients from the
listeners back to the model. The [`Transformer`](/api/transformer) component and
[TransformerListener](/api/architectures#TransformerListener) layer do the same
thing for transformer models, but the `Transformer` component will also save the
transformer outputs to the
[`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
giving you access to them after the pipeline has finished running.
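
The forward half of this mechanism can be illustrated with a small plain-Python toy. It is a conceptual sketch only: `ToyTok2Vec`, `ToyListener`, `predict` and `receive` are hypothetical names rather than spaCy classes or methods, the "embedding" is just each token's length, and backpropagation is not modelled. What it shows is that the shared component embeds each batch once and hands the result to every listener, so downstream components don't recompute it.

```python
class ToyTok2Vec:
    """Hypothetical shared embedding component (not spaCy's Tok2Vec)."""

    def __init__(self):
        self.listeners = []
        self.n_calls = 0  # how often the "expensive" embedding actually ran

    def add_listener(self, listener):
        self.listeners.append(listener)

    def predict(self, docs):
        self.n_calls += 1
        # Stand-in for a real embedding: one 1-d "vector" per token.
        tokvecs = [[[float(len(tok))] for tok in doc] for doc in docs]
        # Pass the predictions forward so the listeners can reuse them.
        for listener in self.listeners:
            listener.receive(docs, tokvecs)
        return tokvecs


class ToyListener:
    """Hypothetical listener layer living inside a downstream component."""

    def __init__(self):
        self._docs = None
        self._outputs = None

    def receive(self, docs, tokvecs):
        self._docs, self._outputs = docs, tokvecs

    def __call__(self, docs):
        # Reuse the cached output if it was produced for this exact batch.
        if docs is self._docs:
            return self._outputs
        raise ValueError("upstream component has not processed this batch")


tok2vec = ToyTok2Vec()
tagger_listener, ner_listener = ToyListener(), ToyListener()
tok2vec.add_listener(tagger_listener)
tok2vec.add_listener(ner_listener)

batch = [["Hello", "world"], ["Shared", "embeddings"]]
tok2vec.predict(batch)
tagger_out = tagger_listener(batch)
ner_out = ner_listener(batch)
print(tok2vec.n_calls)  # 1
```

With independent embedding layers, each component would run its own embedding, so `n_calls` would grow with the number of components.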

### Example: Shared vs. independent config {id="embedding-layers-config"}

The [config system](/usage/training#config) lets you express model configuration
for both shared and independent embedding layers. The shared setup uses a single
[`Tok2Vec`](/api/tok2vec) component with the
[Tok2Vec](/api/architectures#Tok2Vec) architecture. All other components, like
the entity recognizer, use a
[Tok2VecListener](/api/architectures#Tok2VecListener) layer as their model's
`tok2vec` argument, which connects to the `tok2vec` component model.

```ini {title="Shared",highlight="1-2,4-5,19-20"}
[components.tok2vec]
factory = "tok2vec"

[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"

[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v2"

[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"

[components.ner]
factory = "ner"

[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"

[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
```

In the independent setup, the entity recognizer component defines its own
[Tok2Vec](/api/architectures#Tok2Vec) instance. Other components will do the
same. This makes them fully independent and doesn't require an upstream
[`Tok2Vec`](/api/tok2vec) component to be present in the pipeline.

```ini {title="Independent", highlight="7-8"}
[components.ner]
factory = "ner"

[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"

[components.ner.model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

[components.ner.model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v2"

[components.ner.model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
```

{/* TODO: Once rehearsal is tested, mention it here. */}

## Using transformer models {id="transformers"}

Transformers are a family of neural network architectures that compute **dense,
context-sensitive representations** for the tokens in your documents. Downstream
models in your pipeline can then use these representations as input features to
**improve their predictions**. You can connect multiple components to a single
transformer model, with any or all of those components giving feedback to the
transformer to fine-tune it to your tasks. spaCy's transformer support
interoperates with [PyTorch](https://pytorch.org) and the
[HuggingFace `transformers`](https://huggingface.co/transformers/) library,
giving you access to thousands of pretrained models for your pipelines. There
are many [great guides](http://jalammar.github.io/illustrated-transformer/) to
transformer models, but for practical purposes, you can simply think of them as
drop-in replacements that let you achieve **higher accuracy** in exchange for
**higher training and runtime costs**.

### Setup and installation {id="transformers-installation"}

> #### System requirements
>
> We recommend an NVIDIA **GPU** with at least **10GB of memory** in order to
> work with transformer models. Make sure your GPU drivers are up to date and
> you have **CUDA v9+** installed.
>
> The exact requirements will depend on the transformer model. Training a
> transformer-based model without a GPU will be too slow for most practical
> purposes.
>
> Provisioning a new machine will require about **5GB** of data to be
> downloaded: 3GB CUDA runtime, 800MB PyTorch, 400MB CuPy, 500MB weights, 200MB
> spaCy and dependencies.

Once you have CUDA installed, we recommend installing PyTorch following the
[PyTorch installation guidelines](https://pytorch.org/get-started/locally/) for
your package manager and CUDA version. If you skip this step, pip will install
PyTorch as a dependency below, but it may not find the best version for your
setup.

```bash {title="Example: Install PyTorch 1.11.0 for CUDA 11.3 with pip"}
# See: https://pytorch.org/get-started/locally/
$ pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```

Next, install spaCy with the extras for your CUDA version and transformers. The
CUDA extra (e.g., `cuda102`, `cuda113`) installs the correct version of
[`cupy`](https://docs.cupy.dev/en/stable/install.html#installing-cupy), which is
just like `numpy`, but for GPU. You may also need to set the `CUDA_PATH`
environment variable if your CUDA runtime is installed in a non-standard
location. Putting it all together, if you had installed CUDA 11.3 in
`/opt/nvidia/cuda`, you would run:

```bash {title="Installation with CUDA"}
$ export CUDA_PATH="/opt/nvidia/cuda"
$ pip install -U %%SPACY_PKG_NAME[cuda113,transformers]%%SPACY_PKG_FLAGS
```

For [`transformers`](https://huggingface.co/transformers/) v4.0.0+ and models
that require [`SentencePiece`](https://github.com/google/sentencepiece) (e.g.,
ALBERT, CamemBERT, XLNet, Marian, and T5), install the additional dependencies
with:

```bash {title="Install sentencepiece"}
$ pip install transformers[sentencepiece]
```

### Runtime usage {id="transformers-runtime"}

Transformer models can be used as **drop-in replacements** for other types of
neural networks, so your spaCy pipeline can include them in a way that's
completely invisible to the user. Users will download, load and use the model in
the standard way, like any other spaCy pipeline. Instead of using the
transformers as subnetworks directly, you can also use them via the
[`Transformer`](/api/transformer) pipeline component.

The `Transformer` component sets the
[`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
which lets you access the transformer's outputs at runtime. The trained
transformer-based [pipelines](/models) provided by spaCy end in `_trf`, e.g.
[`en_core_web_trf`](/models/en#en_core_web_trf).

```bash
$ python -m spacy download en_core_web_trf
```

```python {title="Example"}
import spacy
from thinc.api import set_gpu_allocator, require_gpu

# Use the GPU, with memory allocations directed via PyTorch.
# This prevents out-of-memory errors that would otherwise occur from competing
# memory pools.
set_gpu_allocator("pytorch")
require_gpu(0)

nlp = spacy.load("en_core_web_trf")
for doc in nlp.pipe(["some text", "some other text"]):
    tokvecs = doc._.trf_data.tensors[-1]
```

You can also customize how the [`Transformer`](/api/transformer) component sets
annotations onto the [`Doc`](/api/doc) by specifying a custom
`set_extra_annotations` function. This callback is called with the batch of
[`Doc`](/api/doc) objects and a
[`FullTransformerBatch`](/api/transformer#fulltransformerbatch) containing the
raw transformer input and output data for the whole batch, so you can implement
whatever you need. Any custom attribute it writes to needs to be registered on
the `Doc` first.

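```python
import spacy
from spacy.tokens import Doc
from spacy_transformers import TransformerData

# Register the custom attribute before the annotation setter writes to it.
Doc.set_extension("custom_attr", default=None)

def custom_annotation_setter(docs, trf_data):
    # trf_data is the FullTransformerBatch for the whole batch; doc_data
    # splits it into one TransformerData object per Doc.
    doc_data = list(trf_data.doc_data)
    for doc, data in zip(docs, doc_data):
        doc._.custom_attr = data

nlp = spacy.load("en_core_web_trf")
nlp.get_pipe("transformer").set_extra_annotations = custom_annotation_setter
doc = nlp("This is a text")
assert isinstance(doc._.custom_attr, TransformerData)
print(doc._.custom_attr.tensors)
```

### Training usage {id="transformers-training"}
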
The recommended workflow for training is to use spaCy's
[config system](/usage/training#config), usually via the
[`spacy train`](/api/cli#train) command. The training config defines all
component settings and hyperparameters in one place and lets you describe a tree
of objects by referring to creation functions, including functions you register
yourself. For details on how to get started with training your own model, check
out the [training quickstart](/usage/training#quickstart).

{/* TODO: <Project id="pipelines/transformers"> */}

{/* The easiest way to get started is to clone a transformers-based project */}
{/* template. Swap in your data, edit the settings and hyperparameters and train, */}
{/* evaluate, package and visualize your model. */}

{/* </Project> */}

The `[components]` section in the [`config.cfg`](/api/data-formats#config)
describes the pipeline components and the settings used to construct them,
including their model implementation. Here's a config snippet for the
[`Transformer`](/api/transformer) component, along with matching Python code. In
this case, the `[components.transformer]` block describes the `transformer`
component:

> #### Python equivalent
>
> ```python
> import spacy
> from spacy_transformers import Transformer, TransformerModel
> from spacy_transformers.annotation_setters import null_annotation_setter
> from spacy_transformers.span_getters import get_doc_spans
>
> nlp = spacy.blank("en")
> trf = Transformer(
>     nlp.vocab,
>     TransformerModel(
>         "bert-base-cased",
>         get_spans=get_doc_spans,
>         tokenizer_config={"use_fast": True},
>     ),
>     set_extra_annotations=null_annotation_setter,
>     max_batch_items=4096,
> )
> ```

```ini {title="config.cfg",excerpt="true"}
[components.transformer]
factory = "transformer"
max_batch_items = 4096

[components.transformer.model]
@architectures = "spacy-transformers.TransformerModel.v3"
name = "bert-base-cased"
tokenizer_config = {"use_fast": true}

[components.transformer.model.get_spans]
@span_getters = "spacy-transformers.doc_spans.v1"

[components.transformer.set_extra_annotations]
@annotation_setters = "spacy-transformers.null_annotation_setter.v1"
```

The `[components.transformer.model]` block describes the `model` argument passed
to the transformer component. It's a Thinc
[`Model`](https://thinc.ai/docs/api-model) object that will be passed into the
component. Here, it references the function
[spacy-transformers.TransformerModel.v3](/api/architectures#TransformerModel)
registered in the [`architectures` registry](/api/top-level#registry). If a key
in a block starts with `@`, it's **resolved to a function** and all other
settings are passed to the function as arguments. In this case, those arguments
are `name`, `tokenizer_config` and `get_spans`.
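
The `@` resolution rule can be sketched as a tiny recursive resolver over nested dicts. This is a toy written for illustration, not the actual config code used by spaCy: the `REGISTRY` dict, the `register` and `resolve` helpers and the two stand-in functions are hypothetical, and the registered names are shortened versions of the ones in the config above. Nested blocks are resolved first, so the span getter already exists by the time it's passed into the model function.

```python
# Hypothetical mini-registry: (registry name, function name) -> function.
REGISTRY = {}


def register(registry_name, func_name):
    def decorator(func):
        REGISTRY[(registry_name, func_name)] = func
        return func
    return decorator


@register("span_getters", "doc_spans.v1")
def make_doc_spans():
    # Stand-in span getter: one "span" (the whole doc) per doc.
    return lambda docs: [[doc] for doc in docs]


@register("architectures", "TransformerModel.v3")
def make_transformer_model(name, tokenizer_config, get_spans):
    # Stand-in "model": just records the arguments it was built with.
    return {"name": name, "tokenizer_config": tokenizer_config, "get_spans": get_spans}


def resolve(block):
    """Resolve nested blocks first, then call the function named by an @ key."""
    if not isinstance(block, dict):
        return block
    func = None
    args = {}
    for key, value in block.items():
        if key.startswith("@"):
            func = REGISTRY[(key[1:], value)]
        else:
            args[key] = resolve(value)
    # Blocks without an @ key stay plain dicts of (resolved) settings.
    return func(**args) if func is not None else args


config = {
    "@architectures": "TransformerModel.v3",
    "name": "bert-base-cased",
    "tokenizer_config": {"use_fast": True},
    "get_spans": {"@span_getters": "doc_spans.v1"},
}
model = resolve(config)
print(model["name"], model["get_spans"](["doc1", "doc2"]))
```

Here `tokenizer_config` has no `@` key, so it's passed through as an ordinary dict, while `get_spans` is replaced by the function its block resolves to.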

`get_spans` is a function that takes a batch of `Doc` objects and returns lists
of potentially overlapping `Span` objects to be processed by the transformer.
Several [built-in functions](/api/transformer#span_getters) are available – for
example, to process the whole document or individual sentences. When the config
is resolved, the function is created and passed into the model as an argument.
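
To make the shape of the `get_spans` contract concrete, here is a plain-Python sketch in which a "doc" is just a list of token strings and a "span" is a `(start, end)` pair. Both getters are hypothetical stand-ins rather than the built-in functions: `whole_doc_spans` mirrors the idea of processing the whole document, while `window_spans` produces fixed-size windows that overlap whenever `stride` is smaller than `window`. Either way, the result has one list of spans per doc.

```python
def whole_doc_spans(docs):
    """One span covering each entire doc."""
    return [[(0, len(doc))] for doc in docs]


def window_spans(docs, window, stride):
    """Fixed-size windows per doc; consecutive windows overlap if stride < window."""
    batch = []
    for doc in docs:
        spans = []
        start = 0
        while True:
            end = min(start + window, len(doc))
            spans.append((start, end))
            if end >= len(doc):
                break
            start += stride
        batch.append(spans)
    return batch


docs = [["a", "b", "c", "d", "e", "f", "g"], ["x", "y"]]
print(whole_doc_spans(docs))  # [[(0, 7)], [(0, 2)]]
print(window_spans(docs, window=4, stride=3))  # [[(0, 4), (3, 7)], [(0, 2)]]
```

In the second output, token 3 falls inside both windows of the first doc, which is the kind of overlap the transformer component has to handle when it maps outputs back onto tokens.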

The `name` value is the name of any [HuggingFace model][huggingface-models],
which will be downloaded automatically the first time it's used. You can also
use a local file path. For full details, see the
[`TransformerModel` docs](/api/architectures#TransformerModel).

[huggingface-models]:
  https://huggingface.co/models?library=pytorch&sort=downloads

A wide variety of PyTorch models are supported, but some might not work. If a
model doesn't seem to work, feel free to open an
[issue](https://github.com/explosion/spacy/issues). Also note that transformers
loaded in spaCy can only be used to produce tensors: pretrained task-specific
heads and text generation features can't be used as part of the `transformer`
pipeline component.

<Infobox variant="warning">

Remember that the `config.cfg` used for training should contain **no missing
values** and requires all settings to be defined. You don't want any hidden
defaults creeping in and changing your results! spaCy will tell you if settings
are missing, and you can run
[`spacy init fill-config`](/api/cli#init-fill-config) to automatically fill in
all defaults.

</Infobox>

### Customizing the settings {id="transformers-training-custom-settings"}

To change any of the settings, you can edit the `config.cfg` and re-run the
training. To change any of the functions, like the span getter, you can replace
the name of the referenced function – e.g.
`@span_getters = "spacy-transformers.sent_spans.v1"` to process sentences. You
can also register your own functions using the
[`span_getters` registry](/api/top-level#registry). For instance, the following
custom function returns [`Span`](/api/span) objects following sentence
boundaries, unless a sentence exceeds `max_length` tokens, in which case it is
split into subsentences of at most `max_length` tokens.

> #### config.cfg
>
> ```ini
> [components.transformer.model.get_spans]
> @span_getters = "custom_sent_spans"
> max_length = 25
> ```

```python {title="code.py"}
import spacy_transformers

@spacy_transformers.registry.span_getters("custom_sent_spans")
def configure_custom_sent_spans(max_length: int):
    def get_custom_sent_spans(docs):
        spans = []
        for doc in docs:
            spans.append([])
            for sent in doc.sents:
                start = 0
                end = max_length
                while end <= len(sent):
                    spans[-1].append(sent[start:end])
                    start += max_length
                    end += max_length
                if start < len(sent):
                    spans[-1].append(sent[start:len(sent)])
        return spans

    return get_custom_sent_spans
```
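
The windowing logic inside `get_custom_sent_spans` can be checked without spaCy by running the same loop over a plain list that stands in for a sentence. `split_sentence` is a hypothetical helper that copies the loop body above; for this purpose, slicing a list behaves like slicing a `Span`.

```python
def split_sentence(sent, max_length):
    """Same loop as get_custom_sent_spans, applied to a plain list."""
    chunks = []
    start = 0
    end = max_length
    # Emit full-length chunks while they fit inside the sentence.
    while end <= len(sent):
        chunks.append(sent[start:end])
        start += max_length
        end += max_length
    # Whatever is left over becomes one shorter final chunk.
    if start < len(sent):
        chunks.append(sent[start:len(sent)])
    return chunks


sentence = list(range(60))  # a 60-token "sentence"
print([len(chunk) for chunk in split_sentence(sentence, 25)])  # [25, 25, 10]
```

Sentences of at most `max_length` tokens come back as a single chunk, and the chunks always cover every token exactly once without overlapping.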

To resolve the config during training, spaCy needs to know about your custom
function. You can make it available via the `--code` argument that can point to
a Python file. For more details on training with custom code, see the
[training documentation](/usage/training#custom-functions).

```bash
python -m spacy train ./config.cfg --code ./code.py
```

### Customizing the model implementations {id="training-custom-model"}

The [`Transformer`](/api/transformer) component expects a Thinc
[`Model`](https://thinc.ai/docs/api-model) object to be passed in as its `model`
argument. You're not limited to the implementation provided by
`spacy-transformers` – the only requirement is that your registered function
must return an object of type ~~Model[List[Doc], FullTransformerBatch]~~: that
is, a Thinc model that takes a list of [`Doc`](/api/doc) objects, and returns a
[`FullTransformerBatch`](/api/transformer#fulltransformerbatch) object with the
transformer data.

The same idea applies to task models that power the **downstream components**.
Most of spaCy's built-in model creation functions support a `tok2vec` argument,
which should be a Thinc layer of type ~~Model[List[Doc], List[Floats2d]]~~. This
is where we'll plug in our transformer model, using the
[TransformerListener](/api/architectures#TransformerListener) layer, which
sneakily delegates to the `Transformer` pipeline component.

```ini {title="config.cfg (excerpt)",highlight="12"}
[components.ner]
factory = "ner"

[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
state_type = "ner"
extra_state_tokens = false
hidden_width = 128
maxout_pieces = 3
use_upper = false

[components.ner.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0

[components.ner.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
```

The [TransformerListener](/api/architectures#TransformerListener) layer expects
a [pooling layer](https://thinc.ai/docs/api-layers#reduction-ops) as the
argument `pooling`, which needs to be of type ~~Model[Ragged, Floats2d]~~. This
layer determines how the vector for each spaCy token will be computed from the
zero or more source rows the token is aligned against. Here we use the
[`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean) layer, which
averages the wordpiece rows. We could instead use
[`reduce_max`](https://thinc.ai/docs/api-layers#reduce_max), or a custom
function you write yourself.
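
As a rough picture of what a pooling layer computes, here is a plain-Python sketch that assumes a toy alignment in which each spaCy token lists the indices of the wordpiece rows it is aligned to. The row values, the alignment and the `pool` helper are made up; Thinc's `reduce_mean` and `reduce_max` work on a `Ragged` array rather than Python lists, and falling back to a zero vector for a token with no aligned rows is only this sketch's choice.

```python
def pool(wordpiece_rows, alignment, reduce):
    """Compute one vector per token from the wordpiece rows it is aligned to."""
    width = len(wordpiece_rows[0])
    token_vectors = []
    for row_ids in alignment:
        if not row_ids:
            # No aligned rows: this sketch falls back to a zero vector.
            token_vectors.append([0.0] * width)
            continue
        rows = [wordpiece_rows[i] for i in row_ids]
        # Reduce each column (dimension) across the aligned rows.
        token_vectors.append([reduce(column) for column in zip(*rows)])
    return token_vectors


def mean(values):
    return sum(values) / len(values)


# Three hypothetical wordpiece rows, e.g. "play", "##ing" and "well".
rows = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]
# Token 0 ("playing") is aligned to rows 0 and 1; token 1 ("well") to row 2.
alignment = [[0, 1], [2]]
print(pool(rows, alignment, mean))  # [[2.0, 2.0], [2.0, 2.0]]
print(pool(rows, alignment, max))   # [[3.0, 4.0], [2.0, 2.0]]
```

Mean pooling blends all wordpieces of a token equally, while max pooling keeps the strongest value per dimension, which is why the two give different vectors for "playing" here.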

You can have multiple components all listening to the same transformer model,
and all passing gradients back to it. By default, all of the gradients will be
**equally weighted**. You can control this with the `grad_factor` setting, which
lets you reweight the gradients from the different listeners. For instance,
setting `grad_factor = 0` would disable gradients from one of the listeners,
while `grad_factor = 2.0` would multiply them by 2. This is similar to having a
custom learning rate for each component. Instead of a constant, you can also
provide a schedule, allowing you to freeze the shared parameters at the start of
training.
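
To illustrate only the arithmetic, here is a plain-Python sketch of how per-listener weights combine the gradients that flow back into a shared layer. `combine_gradients` and `warmup_factor` are hypothetical helpers, not spaCy or Thinc APIs; in a real config you set `grad_factor` on each listener instead. The schedule sketch returns 0 for the first few steps, which is one way to keep the shared parameters frozen at the start of training.

```python
def combine_gradients(listener_grads, grad_factors):
    """Sum the listeners' gradients, scaling each one by its grad_factor."""
    width = len(listener_grads[0])
    total = [0.0] * width
    for grad, factor in zip(listener_grads, grad_factors):
        for i in range(width):
            total[i] += factor * grad[i]
    return total


def warmup_factor(step, frozen_steps=100, value=1.0):
    """Hypothetical schedule: 0 during the first frozen_steps, then constant."""
    return 0.0 if step < frozen_steps else value


tagger_grad = [0.5, -1.0]
ner_grad = [1.0, 2.0]
# Equal weighting, as in the default setup.
print(combine_gradients([tagger_grad, ner_grad], [1.0, 1.0]))  # [1.5, 1.0]
# grad_factor 0 for the tagger, 2.0 for the entity recognizer.
print(combine_gradients([tagger_grad, ner_grad], [0.0, 2.0]))  # [2.0, 4.0]
```

Plugging `warmup_factor(step)` in as every listener's factor would make the shared layer receive a zero gradient until step 100.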

## Static vectors {id="static-vectors"}

If your pipeline includes a **word vectors table**, you'll be able to use the
`.similarity()` method on the [`Doc`](/api/doc), [`Span`](/api/span),
[`Token`](/api/token) and [`Lexeme`](/api/lexeme) objects. You'll also be able
to access the vectors using the `.vector` attribute, or you can look up one or
more vectors directly using the [`Vocab`](/api/vocab) object. Pipelines with
word vectors can also **use the vectors as features** for the statistical
models, which can **improve the accuracy** of your components.
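
By default, spaCy's `Doc.similarity` estimates similarity as the cosine similarity of averaged word vectors. The plain-Python sketch below reproduces that idea with made-up 2-dimensional vectors; `doc_vector` and `cosine` are hypothetical helpers, and real pipelines work with NumPy arrays from the vectors table instead of lists. Treating an all-zero vector as having similarity 0.0 is a choice made in this sketch.

```python
import math


def doc_vector(token_vectors):
    """Average the token vectors element-wise."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]


def cosine(u, v):
    """Cosine similarity: dot product divided by the product of the norms."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # this sketch's convention for all-zero vectors
    return dot / (norm_u * norm_v)


# Hypothetical token vectors for two short texts.
doc1 = doc_vector([[1.0, 0.0], [0.0, 1.0]])  # averages to [0.5, 0.5]
doc2 = doc_vector([[2.0, 2.0]])              # stays [2.0, 2.0]
print(round(cosine(doc1, doc2), 4))  # 1.0
```

Because cosine similarity ignores vector length, the two texts count as maximally similar here even though their averaged vectors differ in magnitude.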

Word vectors in spaCy are "static" in the sense that they are not learned
parameters of the statistical models, and spaCy itself does not feature any
algorithms for learning word vector tables. You can train a word vectors table
using tools such as [floret](https://github.com/explosion/floret),
[Gensim](https://radimrehurek.com/gensim/), [FastText](https://fasttext.cc/) or
[GloVe](https://nlp.stanford.edu/projects/glove/), or download existing
pretrained vectors. The [`init vectors`](/api/cli#init-vectors) command lets you
convert vectors for use with spaCy and will give you a directory you can load or
refer to in your [training configs](/usage/training#config).

<Infobox title="Word vectors and similarity" emoji="📖">

For more details on loading word vectors into spaCy, using them for similarity
and improving word vector coverage by truncating and pruning the vectors, see
the usage guide on
[word vectors and similarity](/usage/linguistic-features#vectors-similarity).

</Infobox>

### Using word vectors in your models {id="word-vectors-models"}

Many neural network models are able to use word vector tables as additional
features, which sometimes results in significant improvements in accuracy.
spaCy's built-in embedding layer,
[MultiHashEmbed](/api/architectures#MultiHashEmbed), can be configured to use
word vector tables using the `include_static_vectors` flag.

```ini
[components.tagger.model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = 128
attrs = ["LOWER","PREFIX","SUFFIX","SHAPE"]
rows = [5000,2500,2500,2500]
include_static_vectors = true
```

<Infobox title="How it works" emoji="💡">

The configuration system will look up the string `"spacy.MultiHashEmbed.v2"` in
the `architectures` [registry](/api/top-level#registry), and call the returned
object with the rest of the arguments from the block. This will result in a call
to the
[`MultiHashEmbed`](https://github.com/explosion/spacy/tree/develop/spacy/ml/models/tok2vec.py)
function, which will return a [Thinc](https://thinc.ai) model object with the
type signature ~~Model[List[Doc], List[Floats2d]]~~. Because the embedding layer
takes a list of `Doc` objects as input, it does not need to store a copy of the
vectors table. The vectors will be retrieved from the `Doc` objects that are
passed in, via the `doc.vocab.vectors` attribute. This part of the process is
handled by the [StaticVectors](/api/architectures#StaticVectors) layer.

</Infobox>
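
In the `MultiHashEmbed` config above, `attrs` and `rows` are parallel lists: each attribute gets its own embedding table with the given number of rows, and feature values are hashed into that table instead of being looked up in a fixed vocabulary. The following heavily simplified plain-Python sketch shows only that table-per-attribute hashing idea. It uses `zlib.crc32`, a single hash per feature and a crude word shape, whereas spaCy's actual layer differs in these details; `hash_row` and `featurize` are hypothetical helpers.

```python
import zlib

attrs = ["LOWER", "PREFIX", "SUFFIX", "SHAPE"]
rows = [5000, 2500, 2500, 2500]  # one table size per attribute


def hash_row(attr, value, n_rows):
    """Map a feature value to a row of its attribute's table via hashing."""
    key = f"{attr}={value}".encode("utf8")
    return zlib.crc32(key) % n_rows


def featurize(token):
    """Hypothetical feature extraction for a single token string."""
    shape = "".join(
        "X" if c.isupper() else "x" if c.islower() else "d" if c.isdigit() else c
        for c in token
    )
    return {
        "LOWER": token.lower(),
        "PREFIX": token[:1],
        "SUFFIX": token[-3:],
        "SHAPE": shape,
    }


features = featurize("Berlin")
# One row index per attribute table; each row would hold a learned vector.
row_ids = [hash_row(attr, features[attr], n) for attr, n in zip(attrs, rows)]
print(row_ids)

# Words sharing a suffix ("lin") share that table row, even rare ones.
same = hash_row("SUFFIX", featurize("Dublin")["SUFFIX"], 2500) == row_ids[2]
print(same)  # True
```

Because every value is hashed into a fixed number of rows, unseen words still get embeddings, and sharing sub-word features such as suffixes lets related words share information.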

#### Creating a custom embedding layer {id="custom-embedding-layer"}

The [MultiHashEmbed](/api/architectures#MultiHashEmbed) layer is spaCy's
recommended strategy for constructing initial word representations for your
neural network models, but you can also implement your own. You can register any
function to a string name, and then reference that function within your config
(see the [training docs](/usage/training) for more details). To try this out,
you can save the following little example to a new Python file:

```python
from typing import List

from spacy.ml.staticvectors import StaticVectors
from spacy.tokens import Doc
from spacy.util import registry
from thinc.api import Model
from thinc.types import Floats2d

print("I was imported!")

@registry.architectures("my_example.MyEmbedding.v1")
def MyEmbedding(output_width: int) -> Model[List[Doc], List[Floats2d]]:
    print("I was called!")
    return StaticVectors(nO=output_width)
```

If you pass the path to your file to the [`spacy train`](/api/cli#train) command
using the `--code` argument, your file will be imported, which means the
decorator registering the function will be run. Your function is now on equal
footing with any of spaCy's built-ins, so you can drop it in instead of any
other model with the same input and output signature. For instance, you could
use it in the tagger model as follows:

```ini
[components.tagger.model.tok2vec.embed]
@architectures = "my_example.MyEmbedding.v1"
output_width = 128
```

Now that you have a custom function wired into the network, you can start
implementing the logic you're interested in. For example, let's say you want to
try a relatively simple embedding strategy that makes use of static word
vectors, but combines them via summation with a smaller table of learned
embeddings.

```python
from typing import Dict, List

from spacy.ml.featureextractor import FeatureExtractor
from spacy.ml.staticvectors import StaticVectors
from spacy.tokens import Doc
from spacy.util import registry
from thinc.api import Embed, Model, add, chain, remap_ids
from thinc.types import Floats2d

@registry.architectures("my_example.MyEmbedding.v1")
def MyCustomVectors(
    output_width: int,
    vector_width: int,
    embed_rows: int,
    key2row: Dict[int, int]
) -> Model[List[Doc], List[Floats2d]]:
    # Sum the static pretrained vectors with a small learned embedding table.
    return add(
        StaticVectors(nO=output_width),
        # Learned embeddings: ORTH IDs -> table rows -> vectors.
        chain(
            FeatureExtractor(["ORTH"]),
            remap_ids(key2row),
            Embed(nO=output_width, nV=embed_rows)
        )
    )
```
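
The summation strategy itself can be shown without Thinc. In the illustrative plain-Python sketch below, a frozen static table and a small learned table are looked up for the same word and added element-wise. Only the learned table is ever modified, in line with spaCy not backpropagating into the static vectors table. The tables, `key2row`, the reserved unknown-word row and the `embed` and `update` helpers are all hypothetical.

```python
static_vectors = {"cat": [0.1, 0.2], "dog": [0.3, 0.1]}  # frozen, pretrained
key2row = {"cat": 0, "dog": 1}                          # word -> learned row
learned = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]          # row 2: unknown words
UNKNOWN_ROW = 2


def embed(word):
    """Sum the static vector and the learned embedding for a word."""
    static = static_vectors.get(word, [0.0, 0.0])
    row = learned[key2row.get(word, UNKNOWN_ROW)]
    return [s + r for s, r in zip(static, row)]


def update(word, gradient, learn_rate=0.5):
    """Apply one gradient step to the learned table only; static stays fixed."""
    row = learned[key2row.get(word, UNKNOWN_ROW)]
    for i, g in enumerate(gradient):
        row[i] -= learn_rate * g


print(embed("cat"))  # [0.1, 0.2]
update("cat", [-0.2, 0.4])
print(embed("cat"))  # [0.2, 0.0]
```

After the update, the learned row for "cat" is `[0.1, -0.2]` while its static vector is still `[0.1, 0.2]`, so the task-specific table adjusts the representation without touching the pretrained vectors.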

## Pretraining {id="pretraining"}

The [`spacy pretrain`](/api/cli#pretrain) command lets you initialize your
models with **information from raw text**. Without pretraining, the models for
your components will usually be initialized randomly. The idea behind
pretraining is simple: random probably isn't optimal, so if we have some text to
learn from, we can probably find a way to get the model off to a better start.

Pretraining uses the same [`config.cfg`](/usage/training#config) file as the
regular training, which helps keep the settings and hyperparameters consistent.
The additional `[pretraining]` section has several configuration subsections
that are familiar from the training block: the `[pretraining.batcher]`,
`[pretraining.optimizer]` and `[pretraining.corpus]` all work the same way and
expect the same types of objects, although for pretraining your corpus does not
need to have any annotations, so you will often use a different reader, such as
the [`JsonlCorpus`](/api/top-level#jsonlcorpus).

> #### Raw text format
>
> The raw text can be provided in spaCy's
> [binary `.spacy` format](/api/data-formats#training) consisting of serialized
> `Doc` objects or as a JSONL (newline-delimited JSON) with a key `"text"` per
> entry. This allows the data to be read in line by line, while also allowing
> you to include newlines in the texts.
>
> ```json
> {"text": "Can I ask where you work now and what you do, and if you enjoy it?"}
> {"text": "They may just pull out of the Seattle market completely, at least until they have autonomous vehicles."}
> ```
>
> You can also use your own custom corpus loader instead.
|
||
|
||
You can add a `[pretraining]` block to your config by setting the
|
||
`--pretraining` flag on [`init config`](/api/cli#init-config) or
|
||
[`init fill-config`](/api/cli#init-fill-config):
|
||
|
||
```bash
$ python -m spacy init fill-config config.cfg config_pretrain.cfg --pretraining
```

You can then run [`spacy pretrain`](/api/cli#pretrain) with the updated config
and pass in optional config overrides, like the path to the raw text file:

```bash
$ python -m spacy pretrain config_pretrain.cfg ./output --paths.raw_text text.jsonl
```

The following defaults are used for the `[pretraining]` block and merged into
your existing config when you run [`init config`](/api/cli#init-config) or
[`init fill-config`](/api/cli#init-fill-config) with `--pretraining`. If needed,
you can [configure](#pretraining-configure) the settings and hyperparameters or
change the [objective](#pretraining-objectives).

```ini
%%GITHUB_SPACY/spacy/default_config_pretraining.cfg
```

### How pretraining works {id="pretraining-details"}

The impact of [`spacy pretrain`](/api/cli#pretrain) varies, but it will usually
be worth trying if you're **not using a transformer** model and you have
**relatively little training data** (for instance, fewer than 5,000 sentences).
A good rule of thumb is that pretraining will generally give you a similar
accuracy improvement to using word vectors in your model. If word vectors have
given you a 10% error reduction, pretraining with spaCy might give you another
10%, for roughly a 20% error reduction in total.

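The rule of thumb adds the two gains to get 20%. If both figures are read as relative reductions of the error that remains (an assumption made here for illustration), they compound rather than add. The function name `combined_error_reduction` is hypothetical:

```python
from typing import List


def combined_error_reduction(reductions: List[float]) -> float:
    """Total relative error reduction after applying each relative reduction in turn."""
    remaining = 1.0  # fraction of the original error still left
    for r in reductions:
        # each step removes a fraction r of whatever error is left
        remaining *= 1.0 - r
    return 1.0 - remaining


# 10% from word vectors, then another 10% (of the remaining error) from pretraining
total = combined_error_reduction([0.10, 0.10])
print(f"{total:.0%}")  # 19%
```

For gains of this size the difference from simple addition is small, so the 20% estimate is a fair approximation.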
The [`spacy pretrain`](/api/cli#pretrain) command will take a **specific
subnetwork** within one of your components and add extra layers to build a
network for a temporary task that forces the model to learn something about
sentence structure and word cooccurrence statistics.

Pretraining produces a **binary weights file** that can be loaded back in at the
start of training, using the configuration option `initialize.init_tok2vec`. The
weights file specifies an initial set of weights. Training then proceeds as
normal.

You can only pretrain one subnetwork from your pipeline at a time, and the
subnetwork must be typed ~~Model[List[Doc], List[Floats2d]]~~ (i.e. it has to be
a "tok2vec" layer). The most common workflow is to use the
[`Tok2Vec`](/api/tok2vec) component to create a shared token-to-vector layer for
several components of your pipeline, and apply pretraining to its whole model.

#### Configuring the pretraining {id="pretraining-configure"}

The [`spacy pretrain`](/api/cli#pretrain) command is configured using the
`[pretraining]` section of your [config file](/usage/training#config). The
`component` and `layer` settings tell spaCy how to **find the subnetwork** to
pretrain. The `layer` setting should be either the empty string (to use the
whole model), or a
[node reference](https://thinc.ai/docs/usage-models#model-state). Most of
spaCy's built-in model architectures have a reference named `"tok2vec"` that
will refer to the right layer. The following are two alternative settings for
the same block:

```ini {title="config.cfg"}
# 1. Use the whole model of the "tok2vec" component
[pretraining]
component = "tok2vec"
layer = ""

# 2. Pretrain the "tok2vec" node of the "textcat" component
[pretraining]
component = "textcat"
layer = "tok2vec"
```

#### Connecting pretraining to training {id="pretraining-training"}

To benefit from pretraining, your training step needs to know to initialize its
`tok2vec` component with the weights learned from the pretraining step. You do
this by setting `initialize.init_tok2vec` to the filename of the `.bin` file
that you want to use from pretraining.

For example, a pretraining step that runs for 5 epochs with an output path of
`pretrain/` produces `pretrain/model0.bin` through `pretrain/model4.bin`. To
make use of the final output, you could fill in this value in your config file:

```ini {title="config.cfg"}
[paths]
init_tok2vec = "pretrain/model4.bin"

[initialize]
init_tok2vec = ${paths.init_tok2vec}
```

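To pick up the last epoch automatically rather than typing the file name, a small helper can scan the pretraining output directory. This is a minimal sketch that assumes the `model<N>.bin` naming from the example above; `latest_pretrain_weights` is a hypothetical helper, not a spaCy function. It compares epoch numbers as integers, because a plain string sort would put `model10.bin` before `model9.bin`.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumed file naming, based on the example above: model0.bin, model1.bin, ...
WEIGHTS_FILE = re.compile(r"model(\d+)\.bin")


def latest_pretrain_weights(output_dir: Path) -> Optional[Path]:
    """Hypothetical helper: return the model<N>.bin file with the largest N, or None."""
    best_epoch, best_path = -1, None
    for path in output_dir.iterdir():
        match = WEIGHTS_FILE.fullmatch(path.name)
        # compare epochs as integers so that model10.bin beats model9.bin
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        out = Path(tmp)
        for epoch in range(5):  # simulate a 5-epoch run: model0.bin ... model4.bin
            (out / f"model{epoch}.bin").write_bytes(b"")
        print(latest_pretrain_weights(out).name)  # model4.bin
```

The returned path could then be supplied to `spacy train` as a config override such as `--paths.init_tok2vec`, in the same way as other `[paths]` values.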
<Infobox variant="warning">

The outputs of `spacy pretrain` are not in the same data format as the
pre-packaged static word vectors that would go into
[`initialize.vectors`](/api/data-formats#config-initialize). The pretraining
output consists of the weights that the `tok2vec` component should start with in
an existing pipeline, so it goes in `initialize.init_tok2vec`.

</Infobox>

#### Pretraining objectives {id="pretraining-objectives"}

> ```ini
> ### Characters objective
> [pretraining.objective]
> @architectures = "spacy.PretrainCharacters.v1"
> maxout_pieces = 3
> hidden_size = 300
> n_characters = 4
> ```
>
> ```ini
> ### Vectors objective
> [pretraining.objective]
> @architectures = "spacy.PretrainVectors.v1"
> maxout_pieces = 3
> hidden_size = 300
> loss = "cosine"
> ```

Two pretraining objectives are available, both of which are variants of the
cloze task that [Devlin et al. (2018)](https://arxiv.org/abs/1810.04805) used to
pretrain BERT. The objective can be defined and configured via the
`[pretraining.objective]` config block.

- [`PretrainCharacters`](/api/architectures#pretrain_chars): The `"characters"`
  objective asks the model to predict some number of leading and trailing UTF-8
  bytes for the words. For instance, with `n_characters = 2`, the model will
  try to predict the first two and last two characters of the word.

- [`PretrainVectors`](/api/architectures#pretrain_vectors): The `"vectors"`
  objective asks the model to predict the word's vector from a static
  embeddings table. This requires a word vectors model to be trained and loaded.
  The vectors objective can optimize either a cosine or an L2 loss. We've
  generally found cosine loss to perform better.

These pretraining objectives use a trick that we term **language modelling with
approximate outputs (LMAO)**. The motivation for the trick is that predicting an
exact word ID introduces a lot of incidental complexity. You need a large output
layer, and even then, the vocabulary is too large, which motivates tokenization
schemes that do not align to actual word boundaries. At the end of training, the
output layer will be thrown away regardless: we just want a task that forces the
network to model something about word cooccurrence statistics. Predicting
leading and trailing characters does that more than adequately, as the exact
word sequence could be recovered with high accuracy if the initial and trailing
characters are predicted accurately. With the vectors objective, the pretraining
uses the embedding space learned by an algorithm such as
[GloVe](https://nlp.stanford.edu/projects/glove/) or
[Word2vec](https://code.google.com/archive/p/word2vec/), allowing the model to
focus on the contextual modelling we actually care about.
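To make the two objectives more concrete, here is a small, framework-free sketch of the targets and losses they work with. It is a simplification under stated assumptions, not spaCy's implementation. `char_targets` works on Unicode characters rather than the UTF-8 bytes the characters objective uses, and pads short words with spaces (an arbitrary choice). `cosine_loss` and `l2_loss` compare a single predicted vector with its target from a static table. All three function names are hypothetical.

```python
import math
from typing import Sequence


def char_targets(word: str, n_characters: int, pad: str = " ") -> str:
    """First and last n_characters of `word`, padded so short words
    still yield 2 * n_characters symbols (hypothetical helper)."""
    if n_characters < 1:
        raise ValueError("n_characters must be at least 1")
    prefix = word[:n_characters].ljust(n_characters, pad)
    suffix = word[-n_characters:].rjust(n_characters, pad)
    return prefix + suffix


def cosine_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """1 - cosine similarity: only the direction of the vectors matters."""
    dot = sum(p * t for p, t in zip(pred, target))
    norms = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    if norms == 0.0:
        raise ValueError("cosine loss is undefined for zero vectors")
    return 1.0 - dot / norms


def l2_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Squared Euclidean distance: penalizes differences in length as well as direction."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))


if __name__ == "__main__":
    print(char_targets("pretraining", 2))  # prng
    target = [3.0, 4.0]
    scaled = [6.0, 8.0]  # same direction as the target, twice as long
    print(cosine_loss(scaled, target))  # 0.0
    print(l2_loss(scaled, target))  # 25.0
```

The last two lines show the practical difference between the losses: a prediction that points in the right direction but has the wrong length costs nothing under the cosine loss, while the L2 loss still penalizes it.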