Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-25 00:34:20 +03:00

Commit f038841798: Merge branch 'develop' of https://github.com/explosion/spaCy into develop
@@ -317,7 +317,8 @@ def test_doc_from_array_morph(en_vocab):
 def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
-    en_texts = ["Merging the docs is fun.", "They don't think alike."]
+    en_texts = ["Merging the docs is fun.", "", "They don't think alike."]
+    en_texts_without_empty = [t for t in en_texts if len(t)]
     de_text = "Wie war die Frage?"
     en_docs = [en_tokenizer(text) for text in en_texts]
     docs_idx = en_texts[0].index("docs")
@@ -338,14 +339,14 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
         Doc.from_docs(en_docs + [de_doc])
 
     m_doc = Doc.from_docs(en_docs)
-    assert len(en_docs) == len(list(m_doc.sents))
+    assert len(en_texts_without_empty) == len(list(m_doc.sents))
     assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
-    assert str(m_doc) == " ".join(en_texts)
+    assert str(m_doc) == " ".join(en_texts_without_empty)
     p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
-    think_idx = len(en_texts[0]) + 1 + en_texts[1].index("think")
+    think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
     assert m_doc[9].idx == think_idx
     with pytest.raises(AttributeError):
         # not callable, because it was not set via set_extension
@@ -353,14 +354,14 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert len(m_doc.user_data) == len(en_docs[0].user_data)  # but it's there
 
     m_doc = Doc.from_docs(en_docs, ensure_whitespace=False)
-    assert len(en_docs) == len(list(m_doc.sents))
-    assert len(str(m_doc)) == len(en_texts[0]) + len(en_texts[1])
+    assert len(en_texts_without_empty) == len(list(m_doc.sents))
+    assert len(str(m_doc)) == sum(len(t) for t in en_texts)
     assert str(m_doc) == "".join(en_texts)
     p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and not bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
-    think_idx = len(en_texts[0]) + 0 + en_texts[1].index("think")
+    think_idx = len(en_texts[0]) + 0 + en_texts[2].index("think")
     assert m_doc[9].idx == think_idx
 
     m_doc = Doc.from_docs(en_docs, attrs=["lemma", "length", "pos"])
@@ -369,12 +370,12 @@ def test_doc_api_from_docs(en_tokenizer, de_tokenizer):
     assert list(m_doc.sents)
     assert len(str(m_doc)) > len(en_texts[0]) + len(en_texts[1])
     # space delimiter considered, although spacy attribute was missing
-    assert str(m_doc) == " ".join(en_texts)
+    assert str(m_doc) == " ".join(en_texts_without_empty)
    p_token = m_doc[len(en_docs[0]) - 1]
     assert p_token.text == "." and bool(p_token.whitespace_)
     en_docs_tokens = [t for doc in en_docs for t in doc]
     assert len(m_doc) == len(en_docs_tokens)
-    think_idx = len(en_texts[0]) + 1 + en_texts[1].index("think")
+    think_idx = len(en_texts[0]) + 1 + en_texts[2].index("think")
     assert m_doc[9].idx == think_idx
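To make the behaviour these updated tests pin down easier to see, here is a minimal sketch of `Doc.from_docs` with an empty doc in the input (assuming spaCy v3 with this change applied; the blank `en` pipeline simply stands in for the test's tokenizer fixture):

```python
import spacy
from spacy.tokens import Doc

nlp = spacy.blank("en")
texts = ["Merging the docs is fun.", "", "They don't think alike."]
docs = [nlp(text) for text in texts]

# Empty docs contribute no tokens; a space is inserted after docs that
# don't already end in whitespace (ensure_whitespace defaults to True).
merged = Doc.from_docs(docs)
assert merged.text == "Merging the docs is fun. They don't think alike."

# With ensure_whitespace=False the texts are concatenated as-is.
merged_raw = Doc.from_docs(docs, ensure_whitespace=False)
assert merged_raw.text == "Merging the docs is fun.They don't think alike."
```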
@@ -34,9 +34,9 @@ cdef class Tokenizer:
                                        vector[SpanC] &filtered)
     cdef int _retokenize_special_spans(self, Doc doc, TokenC* tokens,
                                        object span_data)
-    cdef int _try_cache(self, hash_t key, Doc tokens) except -1
-    cdef int _try_specials(self, hash_t key, Doc tokens,
-                           int* has_special) except -1
+    cdef int _try_specials_and_cache(self, hash_t key, Doc tokens,
+                                     int* has_special,
+                                     bint with_special_cases) except -1
     cdef int _tokenize(self, Doc tokens, unicode span, hash_t key,
                        int* has_special, bint with_special_cases) except -1
     cdef unicode _split_affixes(self, Pool mem, unicode string,
@@ -169,8 +169,6 @@ cdef class Tokenizer:
         cdef int i = 0
         cdef int start = 0
         cdef int has_special = 0
-        cdef bint specials_hit = 0
-        cdef bint cache_hit = 0
         cdef bint in_ws = string[0].isspace()
         cdef unicode span
         # The task here is much like string.split, but not quite
@@ -186,13 +184,7 @@
                     # we don't have to create the slice when we hit the cache.
                     span = string[start:i]
                     key = hash_string(span)
-                    specials_hit = 0
-                    cache_hit = 0
-                    if with_special_cases:
-                        specials_hit = self._try_specials(key, doc, &has_special)
-                    if not specials_hit:
-                        cache_hit = self._try_cache(key, doc)
-                    if not specials_hit and not cache_hit:
+                    if not self._try_specials_and_cache(key, doc, &has_special, with_special_cases):
                         self._tokenize(doc, span, key, &has_special, with_special_cases)
                 if uc == ' ':
                     doc.c[doc.length - 1].spacy = True
@@ -204,13 +196,7 @@
         if start < i:
             span = string[start:]
             key = hash_string(span)
-            specials_hit = 0
-            cache_hit = 0
-            if with_special_cases:
-                specials_hit = self._try_specials(key, doc, &has_special)
-            if not specials_hit:
-                cache_hit = self._try_cache(key, doc)
-            if not specials_hit and not cache_hit:
+            if not self._try_specials_and_cache(key, doc, &has_special, with_special_cases):
                 self._tokenize(doc, span, key, &has_special, with_special_cases)
             doc.c[doc.length - 1].spacy = string[-1] == " " and not in_ws
         return doc
@@ -364,27 +350,33 @@
             offset += span[3]
         return offset
 
-    cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
-        cached = <_Cached*>self._cache.get(key)
-        if cached == NULL:
-            return False
+    cdef int _try_specials_and_cache(self, hash_t key, Doc tokens, int* has_special, bint with_special_cases) except -1:
+        cdef bint specials_hit = 0
+        cdef bint cache_hit = 0
         cdef int i
-        if cached.is_lex:
-            for i in range(cached.length):
-                tokens.push_back(cached.data.lexemes[i], False)
-        else:
-            for i in range(cached.length):
-                tokens.push_back(&cached.data.tokens[i], False)
-        return True
-
-    cdef int _try_specials(self, hash_t key, Doc tokens, int* has_special) except -1:
-        cached = <_Cached*>self._specials.get(key)
-        if cached == NULL:
+        if with_special_cases:
+            cached = <_Cached*>self._specials.get(key)
+            if cached == NULL:
+                specials_hit = False
+            else:
+                for i in range(cached.length):
+                    tokens.push_back(&cached.data.tokens[i], False)
+                has_special[0] = 1
+                specials_hit = True
+        if not specials_hit:
+            cached = <_Cached*>self._cache.get(key)
+            if cached == NULL:
+                cache_hit = False
+            else:
+                if cached.is_lex:
+                    for i in range(cached.length):
+                        tokens.push_back(cached.data.lexemes[i], False)
+                else:
+                    for i in range(cached.length):
+                        tokens.push_back(&cached.data.tokens[i], False)
+                cache_hit = True
+        if not specials_hit and not cache_hit:
             return False
-        cdef int i
-        for i in range(cached.length):
-            tokens.push_back(&cached.data.tokens[i], False)
-        has_special[0] = 1
         return True
 
     cdef int _tokenize(self, Doc tokens, unicode span, hash_t orig_key, int* has_special, bint with_special_cases) except -1:
@@ -462,12 +454,7 @@
         for i in range(prefixes.size()):
             tokens.push_back(prefixes[0][i], False)
         if string:
-            if with_special_cases:
-                specials_hit = self._try_specials(hash_string(string), tokens,
-                                                  has_special)
-            if not specials_hit:
-                cache_hit = self._try_cache(hash_string(string), tokens)
-            if specials_hit or cache_hit:
+            if self._try_specials_and_cache(hash_string(string), tokens, has_special, with_special_cases):
                 pass
             elif (self.token_match and self.token_match(string)) or \
                     (self.url_match and \
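The tokenizer hunks above fold the previous two-step lookup (special cases, then the token cache) into a single `_try_specials_and_cache` call. A rough pure-Python sketch of that lookup order, not the actual Cython implementation (the dict-based `specials` and `cache` arguments are stand-ins for the tokenizer's internal tables):

```python
def try_specials_and_cache(key, out_tokens, specials, cache, with_special_cases):
    # Special cases are consulted first, but only when they are enabled.
    if with_special_cases and key in specials:
        out_tokens.extend(specials[key])
        return True
    # Otherwise fall back to the cache of previously tokenized strings.
    if key in cache:
        out_tokens.extend(cache[key])
        return True
    # Neither hit: the caller falls through to the full tokenization path.
    return False
```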
@@ -920,7 +920,9 @@ cdef class Doc:
                         warnings.warn(Warnings.W101.format(name=name))
                 else:
                     warnings.warn(Warnings.W102.format(key=key, value=value))
-            char_offset += len(doc.text) if not ensure_whitespace or doc[-1].is_space else len(doc.text) + 1
+            char_offset += len(doc.text)
+            if ensure_whitespace and not (len(doc) > 0 and doc[-1].is_space):
+                char_offset += 1
 
         arrays = [doc.to_array(attrs) for doc in docs]
@@ -932,7 +934,7 @@
             token_offset = -1
             for doc in docs[:-1]:
                 token_offset += len(doc)
-                if not doc[-1].is_space:
+                if not (len(doc) > 0 and doc[-1].is_space):
                     concat_spaces[token_offset] = True
 
         concat_array = numpy.concatenate(arrays)
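The `len(doc) > 0` guards added to `Doc.from_docs` above matter because indexing `doc[-1]` on an empty `Doc` raises an `IndexError`, which is what the old one-liners ran into. A small sketch of that edge case (the blank `en` pipeline is just for illustration):

```python
import spacy

nlp = spacy.blank("en")
empty = nlp("")   # an empty Doc has no tokens
try:
    empty[-1]     # the old code effectively did this for every doc
except IndexError:
    print("empty docs have no last token, hence the len(doc) > 0 guard")
```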
@@ -25,36 +25,6 @@ usage documentation on
 ## Tok2Vec architectures {#tok2vec-arch source="spacy/ml/models/tok2vec.py"}
 
-### spacy.HashEmbedCNN.v1 {#HashEmbedCNN}
-
-> #### Example Config
->
-> ```ini
-> [model]
-> @architectures = "spacy.HashEmbedCNN.v1"
-> pretrained_vectors = null
-> width = 96
-> depth = 4
-> embed_size = 2000
-> window_size = 1
-> maxout_pieces = 3
-> subword_features = true
-> ```
-
-Build spaCy's "standard" embedding layer, which uses hash embedding with subword
-features and a CNN with layer-normalized maxout.
-
-| Name | Description |
-| -------------------- | ----------- |
-| `width` | The width of the input and output. These are required to be the same, so that residual connections can be used. Recommended values are `96`, `128` or `300`. ~~int~~ |
-| `depth` | The number of convolutional layers to use. Recommended values are between `2` and `8`. ~~int~~ |
-| `embed_size` | The number of rows in the hash embedding tables. This can be surprisingly small, due to the use of the hash embeddings. Recommended values are between `2000` and `10000`. ~~int~~ |
-| `window_size` | The number of tokens on either side to concatenate during the convolutions. The receptive field of the CNN will be `depth * (window_size * 2 + 1)`, so a 4-layer network with a window size of `2` will be sensitive to 17 words at a time. Recommended value is `1`. ~~int~~ |
-| `maxout_pieces` | The number of pieces to use in the maxout non-linearity. If `1`, the [`Mish`](https://thinc.ai/docs/api-layers#mish) non-linearity is used instead. Recommended values are `1`-`3`. ~~int~~ |
-| `subword_features` | Whether to also embed subword features, specifically the prefix, suffix and word shape. This is recommended for alphabetic languages like English, but not if single-character tokens are used for a language such as Chinese. ~~bool~~ |
-| `pretrained_vectors` | Whether to also use static vectors. ~~bool~~ |
-| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
-
 ### spacy.Tok2Vec.v1 {#Tok2Vec}
 
 > #### Example config
@@ -72,7 +42,8 @@ features and a CNN with layer-normalized maxout.
 > # ...
 > ```
 
-Construct a tok2vec model out of embedding and encoding subnetworks. See the
+Construct a tok2vec model out of two subnetworks: one for embedding and one for
+encoding. See the
 ["Embed, Encode, Attend, Predict"](https://explosion.ai/blog/deep-learning-formula-nlp)
 blog post for background.
@@ -82,6 +53,39 @@ blog post for background.
 | `encode` | Encode context into the embeddings, using an architecture such as a CNN, BiLSTM or transformer. For example, [MaxoutWindowEncoder](/api/architectures#MaxoutWindowEncoder). ~~Model[List[Floats2d], List[Floats2d]]~~ |
 | **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
 
+### spacy.HashEmbedCNN.v1 {#HashEmbedCNN}
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.HashEmbedCNN.v1"
+> pretrained_vectors = null
+> width = 96
+> depth = 4
+> embed_size = 2000
+> window_size = 1
+> maxout_pieces = 3
+> subword_features = true
+> ```
+
+Build spaCy's "standard" tok2vec layer. This layer is defined by a
+[MultiHashEmbed](/api/architectures#MultiHashEmbed) embedding layer that uses
+subword features, and a
+[MaxoutWindowEncoder](/api/architectures#MaxoutWindowEncoder) encoding layer
+consisting of a CNN and a layer-normalized maxout activation function.
+
+| Name | Description |
+| -------------------- | ----------- |
+| `width` | The width of the input and output. These are required to be the same, so that residual connections can be used. Recommended values are `96`, `128` or `300`. ~~int~~ |
+| `depth` | The number of convolutional layers to use. Recommended values are between `2` and `8`. ~~int~~ |
+| `embed_size` | The number of rows in the hash embedding tables. This can be surprisingly small, due to the use of the hash embeddings. Recommended values are between `2000` and `10000`. ~~int~~ |
+| `window_size` | The number of tokens on either side to concatenate during the convolutions. The receptive field of the CNN will be `depth * (window_size * 2 + 1)`, so a 4-layer network with a window size of `2` will be sensitive to 17 words at a time. Recommended value is `1`. ~~int~~ |
+| `maxout_pieces` | The number of pieces to use in the maxout non-linearity. If `1`, the [`Mish`](https://thinc.ai/docs/api-layers#mish) non-linearity is used instead. Recommended values are `1`-`3`. ~~int~~ |
+| `subword_features` | Whether to also embed subword features, specifically the prefix, suffix and word shape. This is recommended for alphabetic languages like English, but not if single-character tokens are used for a language such as Chinese. ~~bool~~ |
+| `pretrained_vectors` | Whether to also use static vectors. ~~bool~~ |
+| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
+
 ### spacy.Tok2VecListener.v1 {#Tok2VecListener}
 
 > #### Example config
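As a hedged aside on the `spacy.HashEmbedCNN.v1` section added above: the same architecture can be resolved from the function registry and instantiated directly in Python, using the values from the example config (the registry lookup is standard spaCy v3; the parameter values are simply the documented ones):

```python
import spacy

# Resolve the registered architecture function by name and build the layer.
build_tok2vec = spacy.registry.architectures.get("spacy.HashEmbedCNN.v1")
tok2vec = build_tok2vec(
    width=96,
    depth=4,
    embed_size=2000,
    window_size=1,
    maxout_pieces=3,
    subword_features=True,
    pretrained_vectors=None,
)
print(tok2vec)  # a Thinc Model[List[Doc], List[Floats2d]]
```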
@@ -10,49 +10,72 @@ menu:
 next: /usage/projects
 ---
 
-A **model architecture** is a function that wires up a
-[Thinc `Model`](https://thinc.ai/docs/api-model) instance, which you can then
-use in a component or as a layer of a larger network. You can use Thinc as a
-thin wrapper around frameworks such as PyTorch, TensorFlow or MXNet, or you can
-implement your logic in Thinc directly. spaCy's built-in components will never
-construct their `Model` instances themselves, so you won't have to subclass the
-component to change its model architecture. You can just **update the config**
-so that it refers to a different registered function. Once the component has
-been created, its model instance has already been assigned, so you cannot change
-its model architecture. The architecture is like a recipe for the network, and
-you can't change the recipe once the dish has already been prepared. You have to
-make a new one.
+> #### Example
+>
+> ```python
+> from thinc.api import Model, chain
+>
+> @spacy.registry.architectures.register("model.v1")
+> def build_model(width: int, classes: int) -> Model:
+>     tok2vec = build_tok2vec(width)
+>     output_layer = build_output_layer(width, classes)
+>     model = chain(tok2vec, output_layer)
+>     return model
+> ```
+
+A **model architecture** is a function that wires up a
+[Thinc `Model`](https://thinc.ai/docs/api-model) instance. It describes the
+neural network that is run internally as part of a component in a spaCy
+pipeline. To define the actual architecture, you can implement your logic in
+Thinc directly, or you can use Thinc as a thin wrapper around frameworks such as
+PyTorch, TensorFlow and MXNet. Each Model can also be used as a sublayer of a
+larger network, allowing you to freely combine implementations from different
+frameworks into one `Thinc` Model.
+
+spaCy's built-in components require a `Model` instance to be passed to them via
+the config system. To change the model architecture of an existing component,
+you just need to [**update the config**](#swap-architectures) so that it refers
+to a different registered function. Once the component has been created from
+this config, you won't be able to change it anymore. The architecture is like a
+recipe for the network, and you can't change the recipe once the dish has
+already been prepared. You have to make a new one.
+
+```ini
+### config.cfg (excerpt)
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "model.v1"
+width = 512
+classes = 16
+```
 
 ## Type signatures {#type-sigs}
 
-<!-- TODO: update example, maybe simplify definition? -->
-
 > #### Example
 >
 > ```python
-> @spacy.registry.architectures.register("spacy.Tagger.v1")
-> def build_tagger_model(
->     tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
-> ) -> Model[List[Doc], List[Floats2d]]:
->     t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
->     output_layer = Softmax(nO, t2v_width, init_W=zero_init)
->     softmax = with_array(output_layer)
->     model = chain(tok2vec, softmax)
->     model.set_ref("tok2vec", tok2vec)
->     model.set_ref("softmax", output_layer)
->     model.set_ref("output_layer", output_layer)
+> from typing import List
+> from thinc.api import Model, chain
+> from thinc.types import Floats2d
+> def chain_model(
+>     tok2vec: Model[List[Doc], List[Floats2d]],
+>     layer1: Model[List[Floats2d], Floats2d],
+>     layer2: Model[Floats2d, Floats2d]
+> ) -> Model[List[Doc], Floats2d]:
+>     model = chain(tok2vec, layer1, layer2)
+>     return model
 > ```
 
 The Thinc `Model` class is a **generic type** that can specify its input and
 output types. Python uses a square-bracket notation for this, so the type
 ~~Model[List, Dict]~~ says that each batch of inputs to the model will be a
-list, and the outputs will be a dictionary. Both `typing.List` and `typing.Dict`
-are also generics, allowing you to be more specific about the data. For
-instance, you can write ~~Model[List[Doc], Dict[str, float]]~~ to specify that
-the model expects a list of [`Doc`](/api/doc) objects as input, and returns a
-dictionary mapping strings to floats. Some of the most common types you'll see
-are:
+list, and the outputs will be a dictionary. You can be even more specific and
+write for instance ~~Model[List[Doc], Dict[str, float]]~~ to specify that the
+model expects a list of [`Doc`](/api/doc) objects as input, and returns a
+dictionary mapping of strings to floats. Some of the most common types you'll
+see are:
 
 | Type | Description |
 | ------------------ | ---------------------------------------------------------------------------------------------------- |
@@ -61,7 +84,7 @@ are:
 | ~~Ints2d~~ | A two-dimensional `numpy` or `cupy` array of integers. Common dtypes include uint64, int32 and int8. |
 | ~~List[Floats2d]~~ | A list of two-dimensional arrays, generally with one array per `Doc` and one row per token. |
 | ~~Ragged~~ | A container to handle variable-length sequence data in an unpadded contiguous array. |
-| ~~Padded~~ | A container to handle variable-length sequence data in a passed contiguous array. |
+| ~~Padded~~ | A container to handle variable-length sequence data in a padded contiguous array. |
 
 The model type signatures help you figure out which model architectures and
 components can **fit together**. For instance, the
@@ -77,10 +100,10 @@ interchangeably. There are many other ways they could be incompatible. However,
 if the types don't match, they almost surely _won't_ be compatible. This little
 bit of validation goes a long way, especially if you
 [configure your editor](https://thinc.ai/docs/usage-type-checking) or other
-tools to highlight these errors early. Thinc will also verify that your types
-match correctly when your config file is processed at the beginning of training.
+tools to highlight these errors early. The config file is also validated at the
+beginning of training, to verify that all the types match correctly.
 
-<Infobox title="Tip: Static type checking in your editor" emoji="💡">
+<Accordion title="Tip: Static type checking in your editor">
 
 If you're using a modern editor like Visual Studio Code, you can
 [set up `mypy`](https://thinc.ai/docs/usage-type-checking#install) with the
@@ -89,35 +112,114 @@ code.
 [![](../images/thinc_mypy.jpg)](https://thinc.ai/docs/usage-type-checking#linting)
 
-</Infobox>
+</Accordion>
 
 ## Swapping model architectures {#swap-architectures}
 
-<!-- TODO: textcat example, using different architecture in the config -->
+If no model is specified for the [`TextCategorizer`](/api/textcategorizer), the
+[TextCatEnsemble](/api/architectures#TextCatEnsemble) architecture is used by
+default. This architecture combines a simple bag-of-words model with a neural
+network, usually resulting in the most accurate results, but at the cost of
+speed. The config file for this model would look something like this:
+
+```ini
+### config.cfg (excerpt)
+[components.textcat]
+factory = "textcat"
+labels = []
+
+[components.textcat.model]
+@architectures = "spacy.TextCatEnsemble.v1"
+exclusive_classes = false
+pretrained_vectors = null
+width = 64
+conv_depth = 2
+embed_size = 2000
+window_size = 1
+ngram_size = 1
+dropout = 0
+nO = null
+```
+
+spaCy has two additional built-in `textcat` architectures, and you can easily
+use those by swapping out the definition of the textcat's model. For instance,
+to use the simple and fast bag-of-words model
+[TextCatBOW](/api/architectures#TextCatBOW), you can change the config to:
+
+```ini
+### config.cfg (excerpt) {highlight="6-10"}
+[components.textcat]
+factory = "textcat"
+labels = []
+
+[components.textcat.model]
+@architectures = "spacy.TextCatBOW.v1"
+exclusive_classes = false
+ngram_size = 1
+no_output_layer = false
+nO = null
+```
+
+For details on all pre-defined architectures shipped with spaCy and how to
+configure them, check out the [model architectures](/api/architectures)
+documentation.
 
 ### Defining sublayers {#sublayers}
 
 Model architecture functions often accept **sublayers as arguments**, so that
 you can try **substituting a different layer** into the network. Depending on
 how the architecture function is structured, you might be able to define your
 network structure entirely through the [config system](/usage/training#config),
-using layers that have already been defined. The
-[transformers documentation](/usage/embeddings-transformers#transformers)
-section shows a common example of swapping in a different sublayer.
+using layers that have already been defined.
 
 In most neural network models for NLP, the most important parts of the network
 are what we refer to as the
-[embed and encode](https://explosion.ai/blog/embed-encode-attend-predict) steps.
+[embed and encode](https://explosion.ai/blog/deep-learning-formula-nlp) steps.
 These steps together compute dense, context-sensitive representations of the
-tokens. Most of spaCy's default architectures accept a
-[`tok2vec` embedding layer](/api/architectures#tok2vec-arch) as an argument, so
-you can control this important part of the network separately. This makes it
-easy to **switch between** transformer, CNN, BiLSTM or other feature extraction
-approaches. And if you want to define your own solution, all you need to do is
-register a ~~Model[List[Doc], List[Floats2d]]~~ architecture function, and
-you'll be able to try it out in any of spaCy components.
+tokens, and their combination forms a typical
+[`Tok2Vec`](/api/architectures#Tok2Vec) layer:
 
-<!-- TODO: example of swapping sublayers -->
+```ini
+### config.cfg (excerpt)
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.Tok2Vec.v1"
+
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+# ...
+
+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
+# ...
+```
+
+By defining these sublayers specifically, it becomes straightforward to swap out
+a sublayer for another one, for instance changing the first sublayer to a
+character embedding with the [CharacterEmbed](/api/architectures#CharacterEmbed)
+architecture:
+
+```ini
+### config.cfg (excerpt)
+[components.tok2vec.model.embed]
+@architectures = "spacy.CharacterEmbed.v1"
+# ...
+
+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
+# ...
+```
+
+Most of spaCy's default architectures accept a `tok2vec` layer as a sublayer
+within the larger task-specific neural network. This makes it easy to **switch
+between** transformer, CNN, BiLSTM or other feature extraction approaches. The
+[transformers documentation](/usage/embeddings-transformers#training-custom-model)
+section shows an example of swapping out a model's standard `tok2vec` layer with
+a transformer. And if you want to define your own solution, all you need to do
+is register a ~~Model[List[Doc], List[Floats2d]]~~ architecture function, and
+you'll be able to try it out in any of the spaCy components.
 
 ## Wrapping PyTorch, TensorFlow and other frameworks {#frameworks}
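A small illustration of the embed/encode split described in the sublayers section above: once a pipeline is created, the `tok2vec` component's Thinc model exposes its sublayers, so they can be inspected (or swapped) programmatically. This is a sketch under the assumption of a blank `en` pipeline with the default `tok2vec` factory settings:

```python
import spacy

nlp = spacy.blank("en")
tok2vec = nlp.add_pipe("tok2vec")  # default embed + encode sublayers

# The component's model corresponds to [components.tok2vec.model] in the
# config; its sublayers mirror the .embed and .encode blocks.
for layer in tok2vec.model.layers:
    print(layer.name)
```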
@@ -377,7 +377,8 @@ A **model architecture** is a function that wires up a Thinc
 component or as a layer of a larger network. You can use Thinc as a thin
 [wrapper around frameworks](https://thinc.ai/docs/usage-frameworks) such as
 PyTorch, TensorFlow or MXNet, or you can implement your logic in Thinc
-[directly](https://thinc.ai/docs/usage-models).
+[directly](https://thinc.ai/docs/usage-models). For more details and examples,
+see the usage guide on [layers and architectures](/usage/layers-architectures).
 
 spaCy's built-in components will never construct their `Model` instances
 themselves, so you won't have to subclass the component to change its model
@@ -395,8 +396,6 @@ different tasks. For example:
 | [TransitionBasedParser](/api/architectures#TransitionBasedParser) | Build a [transition-based parser](https://explosion.ai/blog/parsing-english-in-python) model used in the default [`EntityRecognizer`](/api/entityrecognizer) and [`DependencyParser`](/api/dependencyparser). ~~Model[List[Docs], List[List[Floats2d]]]~~ |
 | [TextCatEnsemble](/api/architectures#TextCatEnsemble) | Stacked ensemble of a bag-of-words model and a neural network model with an internal CNN embedding layer. Used in the default [`TextCategorizer`](/api/textcategorizer). ~~Model[List[Doc], Floats2d]~~ |
 
-<!-- TODO: link to not yet existing usage page on custom architectures etc. -->
-
 ### Metrics, training output and weighted scores {#metrics}
 
 When you train a model using the [`spacy train`](/api/cli#train) command, you'll
@@ -474,11 +473,9 @@ Each custom function can have any numbers of arguments that are passed in via
 the [config](#config), just the built-in functions. If your function defines
 **default argument values**, spaCy is able to auto-fill your config when you run
 [`init fill-config`](/api/cli#init-fill-config). If you want to make sure that a
-given parameter is always explicitely set in the config, avoid setting a default
+given parameter is always explicitly set in the config, avoid setting a default
 value for it.
 
-<!-- TODO: possibly link to new (not yet created) page on creating models ? -->
-
 ### Training with custom code {#custom-code}
 
 > #### Example
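A hedged sketch of the point about default values above: parameters without defaults must always be set in the config, while parameters with defaults can be auto-filled by `init fill-config`. The registry entry and function below are hypothetical, not part of spaCy:

```python
import spacy

@spacy.registry.misc("my_corpus_filter.v1")
def make_corpus_filter(pattern: str, max_items: int = 0):
    # `pattern` has no default, so a config using this function must set it
    # explicitly; `max_items` can be auto-filled with its default of 0.
    def filter_items(items):
        selected = [item for item in items if pattern in item]
        return selected[:max_items] if max_items else selected
    return filter_items
```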
@@ -669,10 +666,9 @@ def custom_logger(log_path):
 
 #### Example: Custom batch size schedule {#custom-code-schedule}
 
-For example, let's say you've implemented your own batch size schedule to use
-during training. The `@spacy.registry.schedules` decorator lets you register
-that function in the `schedules` [registry](/api/top-level#registry) and assign
-it a string name:
+You can also implement your own batch size schedule to use during training. The
+`@spacy.registry.schedules` decorator lets you register that function in the
+`schedules` [registry](/api/top-level#registry) and assign it a string name:
 
 > #### Why the version in the name?
 >
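To make the registration step above concrete, here is a hedged sketch of such a schedule (the name `my_batch_schedule.v1` and the simple growth logic are invented for illustration); the registered name could then be referenced from the config, e.g. via `@schedules` under `[training.batcher.size]`:

```python
from typing import Iterator

import spacy

@spacy.registry.schedules("my_batch_schedule.v1")
def build_batch_schedule(start: int, factor: float) -> Iterator[int]:
    # Yield a batch size that grows by `factor` after every step.
    size = float(start)
    while True:
        yield int(size)
        size *= factor
```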
@@ -806,7 +802,35 @@ def filter_batch(size: int) -> Callable[[Iterable[Example]], Iterator[List[Examp
 
 ### Defining custom architectures {#custom-architectures}
 
-<!-- TODO: this should probably move to new section on models -->
+Built-in pipeline components such as the tagger or named entity recognizer are
+constructed with default neural network [models](/api/architectures). You can
+change the model architecture entirely by implementing your own custom models
+and providing those in the config when creating the pipeline component. See the
+documentation on [layers and model architectures](/usage/layers-architectures)
+for more details.
+
+> ```ini
+> ### config.cfg
+> [components.tagger]
+> factory = "tagger"
+>
+> [components.tagger.model]
+> @architectures = "custom_neural_network.v1"
+> output_width = 512
+> ```
+
+```python
+### functions.py
+from typing import List
+from thinc.types import Floats2d
+from thinc.api import Model
+import spacy
+from spacy.tokens import Doc
+
+@spacy.registry.architectures("custom_neural_network.v1")
+def MyModel(output_width: int) -> Model[List[Doc], List[Floats2d]]:
+    return create_model(output_width)
+```
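A possible follow-up to the example above, hedged: assuming the `custom_neural_network.v1` registration from `functions.py` is importable and `create_model` returns a valid Thinc model, the architecture can also be referenced when adding the component programmatically rather than through a config file:

```python
import spacy

import functions  # noqa: F401 - registers "custom_neural_network.v1"

nlp = spacy.blank("en")
config = {"model": {"@architectures": "custom_neural_network.v1", "output_width": 512}}
tagger = nlp.add_pipe("tagger", config=config)
```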
 
 ## Internal training API {#api}