mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Merge pull request #5894 from explosion/docs/model-docstrings
This commit is contained in:
commit
2611d7a9af
|
@ -1,20 +1,73 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from ...util import registry
|
||||
from .._precomputable_affine import PrecomputableAffine
|
||||
from ..tb_framework import TransitionModel
|
||||
from ...tokens import Doc
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||
def build_tb_parser_model(
|
||||
tok2vec: Model,
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
nr_feature_tokens: int,
|
||||
hidden_width: int,
|
||||
maxout_pieces: int,
|
||||
use_upper: bool = True,
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
"""
|
||||
Build a transition-based parser model. Can apply to NER or dependency-parsing.
|
||||
|
||||
Transition-based parsing is an approach to structured prediction where the
|
||||
task of predicting the structure is mapped to a series of state transitions.
|
||||
You might find this tutorial helpful as background:
|
||||
https://explosion.ai/blog/parsing-english-in-python
|
||||
|
||||
The neural network state prediction model consists of either two or three
|
||||
subnetworks:
|
||||
|
||||
* tok2vec: Map each token into a vector representations. This subnetwork
|
||||
is run once for each batch.
|
||||
* lower: Construct a feature-specific vector for each (token, feature) pair.
|
||||
This is also run once for each batch. Constructing the state
|
||||
representation is then simply a matter of summing the component features
|
||||
and applying the non-linearity.
|
||||
* upper (optional): A feed-forward network that predicts scores from the
|
||||
state representation. If not present, the output from the lower model is
|
||||
used as action scores directly.
|
||||
|
||||
tok2vec (Model[List[Doc], List[Floats2d]]):
|
||||
Subnetwork to map tokens into vector representations.
|
||||
nr_feature_tokens (int): The number of tokens in the context to use to
|
||||
construct the state vector. Valid choices are 1, 2, 3, 6, 8 and 13. The
|
||||
2, 8 and 13 feature sets are designed for the parser, while the 3 and 6
|
||||
feature sets are designed for the NER. The recommended feature sets are
|
||||
3 for NER, and 8 for the dependency parser.
|
||||
|
||||
TODO: This feature should be split into two, state_type: ["deps", "ner"]
|
||||
and extra_state_features: [True, False]. This would map into:
|
||||
|
||||
(deps, False): 8
|
||||
(deps, True): 13
|
||||
(ner, False): 3
|
||||
(ner, True): 6
|
||||
|
||||
hidden_width (int): The width of the hidden layer.
|
||||
maxout_pieces (int): How many pieces to use in the state prediction layer.
|
||||
Recommended values are 1, 2 or 3. If 1, the maxout non-linearity
|
||||
is replaced with a ReLu non-linearity if use_upper=True, and no
|
||||
non-linearity if use_upper=False.
|
||||
use_upper (bool): Whether to use an additional hidden layer after the state
|
||||
vector in order to predict the action scores. It is recommended to set
|
||||
this to False for large pretrained models such as transformers, and False
|
||||
for smaller networks. The upper layer is computed on CPU, which becomes
|
||||
a bottleneck on larger GPU-based models, where it's also less necessary.
|
||||
nO (int or None): The number of actions the model will predict between.
|
||||
Usually inferred from data at the beginning of training, or loaded from
|
||||
disk.
|
||||
"""
|
||||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width),)
|
||||
tok2vec.set_dim("nO", hidden_width)
|
||||
|
|
|
@ -10,10 +10,24 @@ from .._iob import IOB
|
|||
from ...util import registry
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.BiluoTagger.v1")
|
||||
@registry.architectures.register("spacy.BILUOTagger.v1")
|
||||
def BiluoTagger(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]]
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Construct a simple NER tagger, that predicts BILUO tag scores for each
|
||||
token and uses greedy decoding with transition-constraints to return a valid
|
||||
BILUO tag sequence.
|
||||
|
||||
A BILUO tag sequence encodes a sequence of non-overlapping labelled spans
|
||||
into tags assigned to each token. The first token of a span is given the
|
||||
tag B-LABEL, the last token of the span is given the tag L-LABEL, and tokens
|
||||
within the span are given the tag U-LABEL. Single-token spans are given
|
||||
the tag U-LABEL. All other tokens are assigned the tag O.
|
||||
|
||||
The BILUO tag scheme generally results in better linear separation between
|
||||
classes, especially for non-CRF models, because there are more distinct classes
|
||||
for the different situations (Ratinov et al., 2009).
|
||||
"""
|
||||
biluo = BILUO()
|
||||
linear = Linear(
|
||||
nO=None, nI=tok2vec.get_dim("nO"), init_W=configure_normal_init(mean=0.02)
|
||||
|
@ -41,6 +55,15 @@ def BiluoTagger(
|
|||
def IOBTagger(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]]
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Construct a simple NER tagger, that predicts IOB tag scores for each
|
||||
token and uses greedy decoding with transition-constraints to return a valid
|
||||
IOB tag sequence.
|
||||
|
||||
An IOB tag sequence encodes a sequence of non-overlapping labelled spans
|
||||
into tags assigned to each token. The first token of a span is given the
|
||||
tag B-LABEL, and subsequent tokens are given the tag I-LABEL.
|
||||
All other tokens are assigned the tag O.
|
||||
"""
|
||||
biluo = IOB()
|
||||
linear = Linear(nO=None, nI=tok2vec.get_dim("nO"))
|
||||
model = chain(
|
||||
|
|
|
@ -1,11 +1,22 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from thinc.api import zero_init, with_array, Softmax, chain, Model
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from ...util import registry
|
||||
from ...tokens import Doc
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tagger.v1")
|
||||
def build_tagger_model(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
||||
def build_tagger_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Build a tagger model, using a provided token-to-vector component. The tagger
|
||||
model simply adds a linear layer with softmax activation to predict scores
|
||||
given the token vectors.
|
||||
|
||||
tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
|
||||
nO (int or None): The number of tags to output. Inferred from the data if None.
|
||||
"""
|
||||
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
|
||||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
|
||||
|
|
|
@ -45,6 +45,7 @@ def build_bow_text_classifier(
|
|||
no_output_layer: bool,
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
# Don't document this yet, I'm not sure it's right.
|
||||
with Model.define_operators({">>": chain}):
|
||||
sparse_linear = SparseLinear(nO)
|
||||
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
|
||||
|
@ -69,6 +70,7 @@ def build_text_classifier(
|
|||
dropout: Optional[float],
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
# Don't document this yet, I'm not sure it's right.
|
||||
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
|
||||
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
||||
lower = HashEmbed(
|
||||
|
@ -160,6 +162,7 @@ def build_text_classifier_lowdata(
|
|||
dropout: Optional[float],
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
# Don't document this yet, I'm not sure it's right.
|
||||
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
model = (
|
||||
|
|
|
@ -28,11 +28,31 @@ def build_hash_embed_cnn_tok2vec(
|
|||
window_size: int,
|
||||
maxout_pieces: int,
|
||||
subword_features: bool,
|
||||
dropout: Optional[float],
|
||||
pretrained_vectors: Optional[bool]
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
|
||||
with subword features and a CNN with layer-normalized maxout."""
|
||||
with subword features and a CNN with layer-normalized maxout.
|
||||
|
||||
width (int): The width of the input and output. These are required to be the
|
||||
same, so that residual connections can be used. Recommended values are
|
||||
96, 128 or 300.
|
||||
depth (int): The number of convolutional layers to use. Recommended values
|
||||
are between 2 and 8.
|
||||
window_size (int): The number of tokens on either side to concatenate during
|
||||
the convolutions. The receptive field of the CNN will be
|
||||
depth * (window_size * 2 + 1), so a 4-layer network with window_size of
|
||||
2 will be sensitive to 17 words at a time. Recommended value is 1.
|
||||
embed_size (int): The number of rows in the hash embedding tables. This can
|
||||
be surprisingly small, due to the use of the hash embeddings. Recommended
|
||||
values are between 2000 and 10000.
|
||||
maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
|
||||
If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
|
||||
subword_features (bool): Whether to also embed subword features, specifically
|
||||
the prefix, suffix and word shape. This is recommended for alphabetic
|
||||
languages like English, but not if single-character tokens are used for
|
||||
a language such as Chinese.
|
||||
pretrained_vectors (bool): Whether to also use static vectors.
|
||||
"""
|
||||
return build_Tok2Vec_model(
|
||||
embed=MultiHashEmbed(
|
||||
width=width,
|
||||
|
@ -54,7 +74,14 @@ def build_Tok2Vec_model(
|
|||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Construct a tok2vec model out of embedding and encoding subnetworks.
|
||||
See https://explosion.ai/blog/deep-learning-formula-nlp
|
||||
|
||||
embed (Model[List[Doc], List[Floats2d]]): Embed tokens into context-independent
|
||||
word vector representations.
|
||||
encode (Model[List[Floats2d], List[Floats2d]]): Encode context into the
|
||||
embeddings, using an architecture such as a CNN, BiLSTM or transformer.
|
||||
"""
|
||||
receptive_field = encode.attrs.get("receptive_field", 0)
|
||||
tok2vec = chain(embed, with_array(encode, pad=receptive_field))
|
||||
tok2vec.set_dim("nO", encode.get_dim("nO"))
|
||||
|
@ -67,6 +94,27 @@ def build_Tok2Vec_model(
|
|||
def MultiHashEmbed(
|
||||
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
|
||||
):
|
||||
"""Construct an embedding layer that separately embeds a number of lexical
|
||||
attributes using hash embedding, concatenates the results, and passes it
|
||||
through a feed-forward subnetwork to build a mixed representations.
|
||||
|
||||
The features used are the NORM, PREFIX, SUFFIX and SHAPE, which can have
|
||||
varying definitions depending on the Vocab of the Doc object passed in.
|
||||
Vectors from pretrained static vectors can also be incorporated into the
|
||||
concatenated representation.
|
||||
|
||||
width (int): The output width. Also used as the width of the embedding tables.
|
||||
Recommended values are between 64 and 300.
|
||||
rows (int): The number of rows for the embedding tables. Can be low, due
|
||||
to the hashing trick. Embeddings for prefix, suffix and word shape
|
||||
use half as many rows. Recommended values are between 2000 and 10000.
|
||||
also_embed_subwords (bool): Whether to use the PREFIX, SUFFIX and SHAPE
|
||||
features in the embeddings. If not using these, you may need more
|
||||
rows in your hash embeddings, as there will be increased chance of
|
||||
collisions.
|
||||
also_use_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
seed = 7
|
||||
|
||||
|
@ -117,6 +165,30 @@ def MultiHashEmbed(
|
|||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
||||
"""Construct an embedded representations based on character embeddings, using
|
||||
a feed-forward network. A fixed number of UTF-8 byte characters are used for
|
||||
each word, taken from the beginning and end of the word equally. Padding is
|
||||
used in the centre for words that are too short.
|
||||
|
||||
For instance, let's say nC=4, and the word is "jumping". The characters
|
||||
used will be jung (two from the start, two from the end). If we had nC=8,
|
||||
the characters would be "jumpping": 4 from the start, 4 from the end. This
|
||||
ensures that the final character is always in the last position, instead
|
||||
of being in an arbitrary position depending on the word length.
|
||||
|
||||
The characters are embedded in a embedding table with 256 rows, and the
|
||||
vectors concatenated. A hash-embedded vector of the NORM of the word is
|
||||
also concatenated on, and the result is then passed through a feed-forward
|
||||
network to construct a single vector to represent the information.
|
||||
|
||||
width (int): The width of the output vector and the NORM hash embedding.
|
||||
rows (int): The number of rows in the NORM hash embedding table.
|
||||
nM (int): The dimensionality of the character embeddings. Recommended values
|
||||
are between 16 and 64.
|
||||
nC (int): The number of UTF-8 bytes to embed per word. Recommended values
|
||||
are between 3 and 8, although it may depend on the length of words in the
|
||||
language.
|
||||
"""
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
|
@ -133,7 +205,21 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
|||
|
||||
|
||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
|
||||
def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int):
|
||||
def MaxoutWindowEncoder(
|
||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
"""Encode context using convolutions with maxout activation, layer
|
||||
normalization and residual connections.
|
||||
|
||||
width (int): The input and output width. These are required to be the same,
|
||||
to allow residual connections. This value will be determined by the
|
||||
width of the inputs. Recommended values are between 64 and 300.
|
||||
window_size (int): The number of words to concatenate around each token
|
||||
to construct the convolution. Recommended value is 1.
|
||||
maxout_pieces (int): The number of maxout pieces to use. Recommended
|
||||
values are 2 or 3.
|
||||
depth (int): The number of convolutional layers. Recommended value is 4.
|
||||
"""
|
||||
cnn = chain(
|
||||
expand_window(window_size=window_size),
|
||||
Maxout(
|
||||
|
@ -151,7 +237,19 @@ def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth:
|
|||
|
||||
|
||||
@registry.architectures.register("spacy.MishWindowEncoder.v1")
|
||||
def MishWindowEncoder(width, window_size, depth):
|
||||
def MishWindowEncoder(
|
||||
width: int, window_size: int, depth: int
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
"""Encode context using convolutions with mish activation, layer
|
||||
normalization and residual connections.
|
||||
|
||||
width (int): The input and output width. These are required to be the same,
|
||||
to allow residual connections. This value will be determined by the
|
||||
width of the inputs. Recommended values are between 64 and 300.
|
||||
window_size (int): The number of words to concatenate around each token
|
||||
to construct the convolution. Recommended value is 1.
|
||||
depth (int): The number of convolutional layers. Recommended value is 4.
|
||||
"""
|
||||
cnn = chain(
|
||||
expand_window(window_size=window_size),
|
||||
Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
|
||||
|
@ -162,7 +260,18 @@ def MishWindowEncoder(width, window_size, depth):
|
|||
|
||||
|
||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
||||
def BiLSTMEncoder(width, depth, dropout):
|
||||
def BiLSTMEncoder(
|
||||
width: int, depth: int, dropout: float
|
||||
) -> Model[List[Floats2d], List[Floats2d]]:
|
||||
"""Encode context using bidirectonal LSTM layers. Requires PyTorch.
|
||||
|
||||
width (int): The input and output width. These are required to be the same,
|
||||
to allow residual connections. This value will be determined by the
|
||||
width of the inputs. Recommended values are between 64 and 300.
|
||||
window_size (int): The number of words to concatenate around each token
|
||||
to construct the convolution. Recommended value is 1.
|
||||
depth (int): The number of convolutional layers. Recommended value is 4.
|
||||
"""
|
||||
if depth == 0:
|
||||
return noop()
|
||||
return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
|
||||
|
|
|
@ -27,7 +27,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ embed_size = 300
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 2
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_MT_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 2
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from .pipe import Pipe
|
|||
|
||||
default_model_config = """
|
||||
[model]
|
||||
@architectures = "spacy.BiluoTagger.v1"
|
||||
@architectures = "spacy.BILUOTagger.v1"
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
|
@ -26,7 +26,6 @@ embed_size = 7000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_SIMPLE_NER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -31,7 +31,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -48,7 +48,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
"""
|
||||
DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
|
|
@ -48,7 +48,6 @@ window_size = 1
|
|||
embed_size = 2000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
|
@ -78,7 +77,6 @@ embed_size = 5555
|
|||
window_size = 1
|
||||
maxout_pieces = 7
|
||||
subword_features = false
|
||||
dropout = null
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -15,37 +15,194 @@ TODO: intro and how architectures work, link to
|
|||
[`registry`](/api/top-level#registry),
|
||||
[custom models](/usage/training#custom-models) usage etc.
|
||||
|
||||
## Tok2Vec architectures {#tok2vec source="spacy/ml/models/tok2vec.py"}
|
||||
## Tok2Vec architectures {#tok2vec-arch source="spacy/ml/models/tok2vec.py"}
|
||||
|
||||
### spacy.HashEmbedCNN.v1 {#HashEmbedCNN}
|
||||
|
||||
<!-- TODO: intro -->
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.HashEmbedCNN.v1"
|
||||
> # TODO: ...
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> # ...
|
||||
> pretrained_vectors = null
|
||||
> width = 96
|
||||
> depth = 4
|
||||
> embed_size = 2000
|
||||
> window_size = 1
|
||||
> maxout_pieces = 3
|
||||
> subword_features = true
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------------- | ----- | ----------- |
|
||||
| `width` | int | |
|
||||
| `depth` | int | |
|
||||
| `embed_size` | int | |
|
||||
| `window_size` | int | |
|
||||
| `maxout_pieces` | int | |
|
||||
| `subword_features` | bool | |
|
||||
| `dropout` | float | |
|
||||
| `pretrained_vectors` | bool | |
|
||||
Build spaCy's 'standard' tok2vec layer, which uses hash embedding with subword
|
||||
features and a CNN with layer-normalized maxout.
|
||||
|
||||
### spacy.HashCharEmbedCNN.v1 {#HashCharEmbedCNN}
|
||||
| Name | Type | Description |
|
||||
| -------------------- | ---- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `width` | int | The width of the input and output. These are required to be the same, so that residual connections can be used. Recommended values are `96`, `128` or `300`. |
|
||||
| `depth` | int | The number of convolutional layers to use. Recommended values are between `2` and `8`. |
|
||||
| `embed_size` | int | The number of rows in the hash embedding tables. This can be surprisingly small, due to the use of the hash embeddings. Recommended values are between `2000` and `10000`. |
|
||||
| `window_size` | int | The number of tokens on either side to concatenate during the convolutions. The receptive field of the CNN will be `depth * (window_size * 2 + 1)`, so a 4-layer network with a window size of `2` will be sensitive to 17 words at a time. Recommended value is `1`. |
|
||||
| `maxout_pieces` | int | The number of pieces to use in the maxout non-linearity. If `1`, the [`Mish`](https://thinc.ai/docs/api-layers#mish) non-linearity is used instead. Recommended values are `1`-`3`. |
|
||||
| `subword_features` | bool | Whether to also embed subword features, specifically the prefix, suffix and word shape. This is recommended for alphabetic languages like English, but not if single-character tokens are used for a language such as Chinese. |
|
||||
| `pretrained_vectors` | bool | Whether to also use static vectors. |
|
||||
|
||||
### spacy.HashCharEmbedBiLSTM.v1 {#HashCharEmbedBiLSTM}
|
||||
### spacy.Tok2Vec.v1 {#Tok2Vec}
|
||||
|
||||
<!-- TODO: example config -->
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.Tok2Vec.v1"
|
||||
>
|
||||
> [model.embed]
|
||||
>
|
||||
> [model.encode]
|
||||
> ```
|
||||
|
||||
Construct a tok2vec model out of embedding and encoding subnetworks. See the
|
||||
["Embed, Encode, Attend, Predict"](https://explosion.ai/blog/deep-learning-formula-nlp)
|
||||
blog post for background.
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `embed` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Doc]`. **Output:** `List[Floats2d]`. Embed tokens into context-independent word vector representations. |
|
||||
| `encode` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Floats2d]`. **Output:** `List[Floats2d]`. Encode context into the embeddings, using an architecture such as a CNN, BiLSTM or transformer. |
|
||||
|
||||
### spacy.MultiHashEmbed.v1 {#MultiHashEmbed}
|
||||
|
||||
<!-- TODO: check example config -->
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.MultiHashEmbed.v1"
|
||||
> width = 64
|
||||
> rows = 2000
|
||||
> also_embed_subwords = false
|
||||
> also_use_static_vectors = false
|
||||
> ```
|
||||
|
||||
Construct an embedding layer that separately embeds a number of lexical
|
||||
attributes using hash embedding, concatenates the results, and passes it through
|
||||
a feed-forward subnetwork to build a mixed representations. The features used
|
||||
are the `NORM`, `PREFIX`, `SUFFIX` and `SHAPE`, which can have varying
|
||||
definitions depending on the `Vocab` of the `Doc` object passed in. Vectors from
|
||||
pretrained static vectors can also be incorporated into the concatenated
|
||||
representation.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------------- | ---- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `width` | int | The output width. Also used as the width of the embedding tables. Recommended values are between `64` and `300`. |
|
||||
| `rows` | int | The number of rows for the embedding tables. Can be low, due to the hashing trick. Embeddings for prefix, suffix and word shape use half as many rows. Recommended values are between `2000` and `10000`. |
|
||||
| `also_embed_subwords` | bool | Whether to use the `PREFIX`, `SUFFIX` and `SHAPE` features in the embeddings. If not using these, you may need more rows in your hash embeddings, as there will be increased chance of collisions. |
|
||||
| `also_use_static_vectors` | bool | Whether to also use static word vectors. Requires a vectors table to be loaded in the [Doc](/api/doc) objects' vocab. |
|
||||
|
||||
### spacy.CharacterEmbed.v1 {#CharacterEmbed}
|
||||
|
||||
<!-- TODO: check example config -->
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.CharacterEmbed.v1"
|
||||
> width = 64
|
||||
> rows = 2000
|
||||
> nM = 16
|
||||
> nC = 4
|
||||
> ```
|
||||
|
||||
Construct an embedded representations based on character embeddings, using a
|
||||
feed-forward network. A fixed number of UTF-8 byte characters are used for each
|
||||
word, taken from the beginning and end of the word equally. Padding is used in
|
||||
the center for words that are too short.
|
||||
|
||||
For instance, let's say `nC=4`, and the word is "jumping". The characters used
|
||||
will be `"jung"` (two from the start, two from the end). If we had `nC=8`, the
|
||||
characters would be `"jumpping"`: 4 from the start, 4 from the end. This ensures
|
||||
that the final character is always in the last position, instead of being in an
|
||||
arbitrary position depending on the word length.
|
||||
|
||||
The characters are embedded in a embedding table with 256 rows, and the vectors
|
||||
concatenated. A hash-embedded vector of the `NORM` of the word is also
|
||||
concatenated on, and the result is then passed through a feed-forward network to
|
||||
construct a single vector to represent the information.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------- | ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `width` | int | The width of the output vector and the `NORM` hash embedding. |
|
||||
| `rows` | int | The number of rows in the `NORM` hash embedding table. |
|
||||
| `nM` | int | The dimensionality of the character embeddings. Recommended values are between `16` and `64`. |
|
||||
| `nC` | int | The number of UTF-8 bytes to embed per word. Recommended values are between `3` and `8`, although it may depend on the length of words in the language. |
|
||||
|
||||
### spacy.MaxoutWindowEncoder.v1 {#MaxoutWindowEncoder}
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
> width = 64
|
||||
> window_size = 1
|
||||
> maxout_pieces = 2
|
||||
> depth = 4
|
||||
> ```
|
||||
|
||||
Encode context using convolutions with maxout activation, layer normalization
|
||||
and residual connections.
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------------- | ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `width` | int | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. |
|
||||
| `window_size` | int | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. |
|
||||
| `maxout_pieces` | int | The number of maxout pieces to use. Recommended values are `2` or `3`. |
|
||||
| `depth` | int | The number of convolutional layers. Recommended value is `4`. |
|
||||
|
||||
### spacy.MishWindowEncoder.v1 {#MishWindowEncoder}
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.MishWindowEncoder.v1"
|
||||
> width = 64
|
||||
> window_size = 1
|
||||
> depth = 4
|
||||
> ```
|
||||
|
||||
Encode context using convolutions with
|
||||
[`Mish`](https://thinc.ai/docs/api-layers#mish) activation, layer normalization
|
||||
and residual connections.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------- | ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `width` | int | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. |
|
||||
| `window_size` | int | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. |
|
||||
| `depth` | int | The number of convolutional layers. Recommended value is `4`. |
|
||||
|
||||
### spacy.TorchBiLSTMEncoder.v1 {#TorchBiLSTMEncoder}
|
||||
|
||||
> #### Example config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.TorchBiLSTMEncoder.v1"
|
||||
> width = 64
|
||||
> window_size = 1
|
||||
> depth = 4
|
||||
> ```
|
||||
|
||||
Encode context using bidirectonal LSTM layers. Requires
|
||||
[PyTorch](https://pytorch.org).
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------- | ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `width` | int | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. |
|
||||
| `window_size` | int | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. |
|
||||
| `depth` | int | The number of convolutional layers. Recommended value is `4`. |
|
||||
|
||||
## Transformer architectures {#transformers source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/architectures.py"}
|
||||
|
||||
|
@ -98,9 +255,9 @@ architectures into your training config.
|
|||
| `grad_factor` | float | Factor for weighting the gradient if multiple components listen to the same transformer model. |
|
||||
| `pooling` | `Model[Ragged, Floats2d]` | Pooling layer to determine how the vector for each spaCy token will be computed. |
|
||||
|
||||
## Parser & NER architectures {#parser source="spacy/ml/models/parser.py"}
|
||||
## Parser & NER architectures {#parser}
|
||||
|
||||
### spacy.TransitionBasedParser.v1 {#TransitionBasedParser}
|
||||
### spacy.TransitionBasedParser.v1 {#TransitionBasedParser source="spacy/ml/models/parser.py"}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
|
@ -112,24 +269,100 @@ architectures into your training config.
|
|||
> maxout_pieces = 2
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> # ...
|
||||
> @architectures = "spacy.HashEmbedCNN.v1"
|
||||
> pretrained_vectors = null
|
||||
> width = 96
|
||||
> depth = 4
|
||||
> embed_size = 2000
|
||||
> window_size = 1
|
||||
> maxout_pieces = 3
|
||||
> subword_features = true
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------- | ------------------------------------------ | ----------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | |
|
||||
| `nr_feature_tokens` | int | |
|
||||
| `hidden_width` | int | |
|
||||
| `maxout_pieces` | int | |
|
||||
| `use_upper` | bool | |
|
||||
| `nO` | int | |
|
||||
Build a transition-based parser model. Can apply to NER or dependency-parsing.
|
||||
Transition-based parsing is an approach to structured prediction where the task
|
||||
of predicting the structure is mapped to a series of state transitions. You
|
||||
might find [this tutorial](https://explosion.ai/blog/parsing-english-in-python)
|
||||
helpful for background information. The neural network state prediction model
|
||||
consists of either two or three subnetworks:
|
||||
|
||||
- **tok2vec**: Map each token into a vector representations. This subnetwork is
|
||||
run once for each batch.
|
||||
- **lower**: Construct a feature-specific vector for each `(token, feature)`
|
||||
pair. This is also run once for each batch. Constructing the state
|
||||
representation is then simply a matter of summing the component features and
|
||||
applying the non-linearity.
|
||||
- **upper** (optional): A feed-forward network that predicts scores from the
|
||||
state representation. If not present, the output from the lower model is used
|
||||
as action scores directly.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------- | ------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Doc]`. **Output:** `List[Floats2d]`. Subnetwork to map tokens into vector representations. |
|
||||
| `nr_feature_tokens` | int | The number of tokens in the context to use to construct the state vector. Valid choices are `1`, `2`, `3`, `6`, `8` and `13`. The `2`, `8` and `13` feature sets are designed for the parser, while the `3` and `6` feature sets are designed for the entity recognizer. The recommended feature sets are `3` for NER, and `8` for the dependency parser. |
|
||||
| `hidden_width` | int | The width of the hidden layer. |
|
||||
| `maxout_pieces` | int | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. |
|
||||
| `use_upper` | bool | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. |
|
||||
| `nO` | int | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. |
|
||||
|
||||
### spacy.BILUOTagger.v1 {#BILUOTagger source="spacy/ml/models/simple_ner.py"}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.BILUOTagger.v1 "
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> @architectures = "spacy.HashEmbedCNN.v1"
|
||||
> # etc.
|
||||
> ```
|
||||
|
||||
Construct a simple NER tagger that predicts
|
||||
[BILUO](/usage/linguistic-features#accessing-ner) tag scores for each token and
|
||||
uses greedy decoding with transition-constraints to return a valid BILUO tag
|
||||
sequence. A BILUO tag sequence encodes a sequence of non-overlapping labelled
|
||||
spans into tags assigned to each token. The first token of a span is given the
|
||||
tag `B-LABEL`, the last token of the span is given the tag `L-LABEL`, and tokens
|
||||
within the span are given the tag `U-LABEL`. Single-token spans are given the
|
||||
tag `U-LABEL`. All other tokens are assigned the tag `O`. The BILUO tag scheme
|
||||
generally results in better linear separation between classes, especially for
|
||||
non-CRF models, because there are more distinct classes for the different
|
||||
situations ([Ratinov et al., 2009](https://www.aclweb.org/anthology/W09-1119/)).
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Doc]`. **Output:** `List[Floats2d]`. Subnetwork to map tokens into vector representations. |
|
||||
|
||||
### spacy.IOBTagger.v1 {#IOBTagger source="spacy/ml/models/simple_ner.py"}
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.IOBTagger.v1 "
|
||||
>
|
||||
> [model.tok2vec]
|
||||
> @architectures = "spacy.HashEmbedCNN.v1"
|
||||
> # etc.
|
||||
> ```
|
||||
|
||||
Construct a simple NER tagger, that predicts
|
||||
[IOB](/usage/linguistic-features#accessing-ner) tag scores for each token and
|
||||
uses greedy decoding with transition-constraints to return a valid IOB tag
|
||||
sequence. An IOB tag sequence encodes a sequence of non-overlapping labeled
|
||||
spans into tags assigned to each token. The first token of a span is given the
|
||||
tag B-LABEL, and subsequent tokens are given the tag I-LABEL. All other tokens
|
||||
are assigned the tag O.
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Doc]`. **Output:** `List[Floats2d]`. Subnetwork to map tokens into vector representations. |
|
||||
|
||||
## Tagging architectures {#tagger source="spacy/ml/models/tagger.py"}
|
||||
|
||||
### spacy.Tagger.v1 {#Tagger}
|
||||
|
||||
<!-- TODO: intro -->
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
|
@ -141,18 +374,22 @@ architectures into your training config.
|
|||
> # ...
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------------------------------------ | ----------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | |
|
||||
| `nO` | int | |
|
||||
Build a tagger model, using a provided token-to-vector component. The tagger
|
||||
model simply adds a linear layer with softmax activation to predict scores given
|
||||
the token vectors.
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------------- |
|
||||
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | **Input:** `List[Doc]`. **Output:** `List[Floats2d]`. Subnetwork to map tokens into vector representations. |
|
||||
| `nO` | int | The number of tags to output. Inferred from the data if `None`. |
|
||||
|
||||
## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}
|
||||
|
||||
A text classification architecture needs to take a `Doc` as input, and produce a
|
||||
score for each potential label class. Textcat challenges can be binary (e.g.
|
||||
sentiment analysis) or involve multiple possible labels. Multi-label challenges
|
||||
can either have mutually exclusive labels (each example has exactly one label),
|
||||
or multiple labels may be applicable at the same time.
|
||||
A text classification architecture needs to take a [`Doc`](/api/doc) as input,
|
||||
and produce a score for each potential label class. Textcat challenges can be
|
||||
binary (e.g. sentiment analysis) or involve multiple possible labels.
|
||||
Multi-label challenges can either have mutually exclusive labels (each example
|
||||
has exactly one label), or multiple labels may be applicable at the same time.
|
||||
|
||||
As the properties of text classification problems can vary widely, we provide
|
||||
several different built-in architectures. It is recommended to experiment with
|
||||
|
@ -214,7 +451,6 @@ If the `nO` dimension is not set, the TextCategorizer component will set it when
|
|||
> window_size = 1
|
||||
> maxout_pieces = 3
|
||||
> subword_features = true
|
||||
> dropout = null
|
||||
> ```
|
||||
|
||||
A neural network model where token vectors are calculated using a CNN. The
|
||||
|
@ -232,20 +468,20 @@ If the `nO` dimension is not set, the TextCategorizer component will set it when
|
|||
|
||||
### spacy.TextCatBOW.v1 {#TextCatBOW}
|
||||
|
||||
An ngram "bag-of-words" model. This architecture should run much faster than the
|
||||
others, but may not be as accurate, especially if texts are short.
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy.TextCatBOW.v1"
|
||||
> exclusive_classes = false
|
||||
> ngram_size: 1
|
||||
> no_output_layer: false
|
||||
> ngram_size = 1
|
||||
> no_output_layer = false
|
||||
> nO = null
|
||||
> ```
|
||||
|
||||
An ngram "bag-of-words" model. This architecture should run much faster than the
|
||||
others, but may not be as accurate, especially if texts are short.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------- | ----- | ------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `exclusive_classes` | bool | Whether or not categories are mutually exclusive. |
|
||||
|
@ -260,9 +496,9 @@ If the `nO` dimension is not set, the TextCategorizer component will set it when
|
|||
|
||||
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
|
||||
|
||||
An `EntityLinker` component disambiguates textual mentions (tagged as named
|
||||
entities) to unique identifiers, grounding the named entities into the "real
|
||||
world". This requires 3 main components:
|
||||
An [`EntityLinker`](/api/entitylinker) component disambiguates textual mentions
|
||||
(tagged as named entities) to unique identifiers, grounding the named entities
|
||||
into the "real world". This requires 3 main components:
|
||||
|
||||
- A [`KnowledgeBase`](/api/kb) (KB) holding the unique identifiers, potential
|
||||
synonyms and prior probabilities.
|
||||
|
@ -292,7 +528,6 @@ layer.
|
|||
> window_size = 1
|
||||
> maxout_pieces = 3
|
||||
> subword_features = true
|
||||
> dropout = null
|
||||
>
|
||||
> [kb_loader]
|
||||
> @assets = "spacy.EmptyKB.v1"
|
||||
|
|
|
@ -217,7 +217,7 @@ $ python -m spacy convert ./data.json ./output
|
|||
|
||||
</Infobox>
|
||||
|
||||
> #### Annotating entities {#biluo}
|
||||
> #### Annotating entities
|
||||
>
|
||||
> Named entities are provided in the
|
||||
> [BILUO](/usage/linguistic-features#accessing-ner) notation. Tokens outside an
|
||||
|
@ -328,7 +328,7 @@ to keep track of your settings and hyperparameters and your own
|
|||
| `sent_starts` | `List[bool]` | List of boolean values indicating whether each token is the first of a sentence or not. |
|
||||
| `deps` | `List[str]` | List of string values indicating the [dependency relation](/usage/linguistic-features#dependency-parse) of a token to its head. |
|
||||
| `heads` | `List[int]` | List of integer values indicating the dependency head of each token, referring to the absolute index of each token in the text. |
|
||||
| `entities` | `List[str]` | **Option 1:** List of [BILUO tags](#biluo) per token of the format `"{action}-{label}"`, or `None` for unannotated tokens. |
|
||||
| `entities` | `List[str]` | **Option 1:** List of [BILUO tags](/usage/linguistic-features#accessing-ner) per token of the format `"{action}-{label}"`, or `None` for unannotated tokens. |
|
||||
| `entities` | `List[Tuple[int, int, str]]` | **Option 2:** List of `"(start, end, label)"` tuples defining all entities in the text. |
|
||||
| `cats` | `Dict[str, float]` | Dictionary of `label`/`value` pairs indicating how relevant a certain [text category](/api/textcategorizer) is for the text. |
|
||||
| `links` | `Dict[(int, int), Dict]` | Dictionary of `offset`/`dict` pairs defining [named entity links](/usage/linguistic-features#entity-linking). The character offsets are linked to a dictionary of relevant knowledge base IDs. |
|
||||
|
|
Loading…
Reference in New Issue
Block a user