Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-30 19:24:07 +03:00
Support Mish activation (requires Thinc 7.3) (#4536)

* Add arch for MishWindowEncoder
* Support Mish in tok2vec and conv window >= 2
* Pass new tok2vec settings from parser
* Fix syntax error
* Fix tok2vec setting
* Fix registration of MishWindowEncoder
* Fix receptive field setting
* Fix Mish arch
* Pass more options from parser
* Support more tok2vec options in pretrain
* Require Thinc 7.3
* Add docs [ci skip]
* Require thinc 7.3.0.dev0 to run CI
* Run black
* Fix typo
* Update Thinc version

Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent 96bb8f2187
commit d5509e0989
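For reference, Mish is the self-gated activation mish(x) = x · tanh(softplus(x)); the Thinc 7.3 requirement is what provides it. A minimal NumPy sketch of the function itself (our illustration, not the Thinc implementation):

```python
import numpy as np

def mish(x):
    # mish(x) = x * tanh(softplus(x)), where softplus(x) = ln(1 + exp(x))
    return x * np.tanh(np.log1p(np.exp(x)))

print(mish(np.array([-2.0, 0.0, 2.0])))  # smooth and non-monotonic below zero
```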
requirements.txt

```diff
@@ -1,7 +1,7 @@
 # Our libraries
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=7.2.0,<7.3.0
+thinc>=7.3.0,<7.4.0
 blis>=0.4.0,<0.5.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.3.0,<1.1.0
```
setup.cfg

```diff
@@ -38,14 +38,14 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=7.2.0,<7.3.0
+    thinc>=7.3.0,<7.4.0
 install_requires =
     setuptools
     numpy>=1.15.0
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=7.2.0,<7.3.0
+    thinc>=7.3.0,<7.4.0
     blis>=0.4.0,<0.5.0
     plac>=0.9.6,<1.2.0
     requests>=2.13.0,<3.0.0
```
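For anyone installing from source against the bumped pin (a routine step, shown for completeness):

```bash
pip install "thinc>=7.3.0,<7.4.0"
```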
spacy/_ml.py (26 lines changed)
```diff
@@ -321,6 +321,7 @@ def Tok2Vec(width, embed_size, **kwargs):
     char_embed = kwargs.get("char_embed", False)
     conv_depth = kwargs.get("conv_depth", 4)
     bilstm_depth = kwargs.get("bilstm_depth", 0)
+    conv_window = kwargs.get("conv_window", 1)
     cols = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
@@ -362,16 +363,21 @@ def Tok2Vec(width, embed_size, **kwargs):
                     "column": cols.index(ID),
                 },
             }
-    cnn_cfg = {
-        "arch": "spacy.MaxoutWindowEncoder.v1",
-        "config": {
-            "width": width,
-            "window_size": 1,
-            "pieces": cnn_maxout_pieces,
-            "depth": conv_depth,
-        },
-    }
+    if cnn_maxout_pieces >= 2:
+        cnn_cfg = {
+            "arch": "spacy.MaxoutWindowEncoder.v1",
+            "config": {
+                "width": width,
+                "window_size": conv_window,
+                "pieces": cnn_maxout_pieces,
+                "depth": conv_depth,
+            },
+        }
+    else:
+        cnn_cfg = {
+            "arch": "spacy.MishWindowEncoder.v1",
+            "config": {"width": width, "window_size": conv_window, "depth": conv_depth},
+        }
     bilstm_cfg = {
         "arch": "spacy.TorchBiLSTMEncoder.v1",
         "config": {"width": width, "depth": bilstm_depth},
```
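In effect, `cnn_maxout_pieces` now doubles as the encoder switch. A usage sketch against the internal helper (widths and row counts are illustrative, not recommendations):

```python
from spacy._ml import Tok2Vec

# pieces >= 2 keeps the maxout window encoder (spacy.MaxoutWindowEncoder.v1)
maxout_t2v = Tok2Vec(96, 2000, conv_window=1, cnn_maxout_pieces=3)

# pieces == 1 selects the new Mish encoder (spacy.MishWindowEncoder.v1),
# which is why Thinc 7.3 is required
mish_t2v = Tok2Vec(96, 2000, conv_window=2, cnn_maxout_pieces=1)
```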
spacy/cli/pretrain.py
```diff
@@ -35,6 +35,10 @@ from .train import _load_pretrained_tok2vec
     output_dir=("Directory to write models to on each epoch", "positional", None, str),
     width=("Width of CNN layers", "option", "cw", int),
     depth=("Depth of CNN layers", "option", "cd", int),
+    cnn_window=("Window size for CNN layers", "option", "cW", int),
+    cnn_pieces=("Maxout size for CNN layers. 1 for Mish", "option", "cP", int),
+    use_chars=("Whether to use character-based embedding", "flag", "chr", bool),
+    sa_depth=("Depth of self-attention layers", "option", "sa", int),
     bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int),
    embed_rows=("Number of embedding rows", "option", "er", int),
     loss_func=(
@@ -81,7 +85,11 @@ def pretrain(
     output_dir,
     width=96,
     depth=4,
-    bilstm_depth=2,
+    bilstm_depth=0,
+    cnn_pieces=3,
+    sa_depth=0,
+    use_chars=False,
+    cnn_window=1,
     embed_rows=2000,
     loss_func="cosine",
     use_vectors=False,
@@ -158,8 +166,8 @@ def pretrain(
             conv_depth=depth,
             pretrained_vectors=pretrained_vectors,
             bilstm_depth=bilstm_depth,  # Requires PyTorch. Experimental.
-            cnn_maxout_pieces=3,  # You can try setting this higher
-            subword_features=True,  # Set to False for Chinese etc
+            subword_features=not use_chars,  # Set to False for Chinese etc
+            cnn_maxout_pieces=cnn_pieces,  # If set to 1, use Mish activation.
         ),
     )
     # Load in pretrained weights
```
spacy/ml/tok2vec.py
```diff
@@ -16,8 +16,8 @@ def Tok2Vec(config):
     doc2feats = make_layer(config["@doc2feats"])
     embed = make_layer(config["@embed"])
     encode = make_layer(config["@encode"])
-    depth = config["@encode"]["config"]["depth"]
-    tok2vec = chain(doc2feats, with_flatten(chain(embed, encode), pad=depth))
+    field_size = getattr(encode, "receptive_field", 0)
+    tok2vec = chain(doc2feats, with_flatten(chain(embed, encode), pad=field_size))
     tok2vec.cfg = config
     tok2vec.nO = encode.nO
     tok2vec.embed = embed
@@ -84,6 +84,21 @@ def MaxoutWindowEncoder(config):
     )
     model = clone(Residual(cnn), depth)
     model.nO = nO
+    model.receptive_field = nW * depth
+    return model
+
+
+@register_architecture("spacy.MishWindowEncoder.v1")
+def MishWindowEncoder(config):
+    from thinc.v2v import Mish
+
+    nO = config["width"]
+    nW = config["window_size"]
+    depth = config["depth"]
+
+    cnn = chain(ExtractWindow(nW=nW), LayerNorm(Mish(nO, nO * ((nW * 2) + 1))))
+    model = clone(Residual(cnn), depth)
+    model.nO = nO
     return model
 
 
```
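The padding change is the subtle part: a stack of `depth` residual window encoders, each seeing `nW` tokens on either side, has an effective receptive field of `nW * depth` tokens, and `with_flatten` must pad that much between concatenated docs. Encoders that don't declare `receptive_field` fall back to `0` via `getattr`. A minimal sketch of the arithmetic:

```python
def receptive_field(nW, depth):
    # each of `depth` stacked window encoders widens the context by
    # nW tokens per side, so flattening needs nW * depth tokens of padding
    return nW * depth

assert receptive_field(nW=1, depth=4) == 4  # matches the old pad=depth default
assert receptive_field(nW=2, depth=4) == 8  # a wider window needs more padding
```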
spacy/syntax/nn_parser.pyx
```diff
@@ -57,7 +57,10 @@ cdef class Parser:
         subword_features = util.env_opt('subword_features',
             cfg.get('subword_features', True))
         conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
+        conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1))
+        t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
         bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
+        self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
         if depth != 1:
             raise ValueError(TempErrors.T004.format(value=depth))
         parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
@@ -69,6 +72,8 @@ cdef class Parser:
         pretrained_vectors = cfg.get('pretrained_vectors', None)
         tok2vec = Tok2Vec(token_vector_width, embed_size,
                           conv_depth=conv_depth,
+                          conv_window=conv_window,
+                          cnn_maxout_pieces=t2v_pieces,
                           subword_features=subword_features,
                           pretrained_vectors=pretrained_vectors,
                           bilstm_depth=bilstm_depth)
@@ -90,7 +95,12 @@ cdef class Parser:
             'hidden_width': hidden_width,
             'maxout_pieces': parser_maxout_pieces,
             'pretrained_vectors': pretrained_vectors,
-            'bilstm_depth': bilstm_depth
+            'bilstm_depth': bilstm_depth,
+            'self_attn_depth': self_attn_depth,
+            'conv_depth': conv_depth,
+            'conv_window': conv_window,
+            'embed_size': embed_size,
+            'cnn_maxout_pieces': t2v_pieces
         }
         return ParserModel(tok2vec, lower, upper), cfg
 
```
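Because the new settings flow through `util.env_opt`, they can be overridden per process without code changes. A hypothetical sketch, assuming `env_opt`'s convention of checking `SPACY_`-prefixed environment variables before the passed config:

```bash
# Hypothetical overrides; variable names follow the SPACY_<OPTION> convention
export SPACY_CONV_WINDOW=2         # widen the tok2vec convolution window
export SPACY_CNN_MAXOUT_PIECES=1   # 1 selects the Mish window encoder
python -m spacy train en ./model ./train.json ./dev.json
```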
website/docs/api/cli.md
````diff
@@ -446,8 +446,10 @@ improvement.
 
 ```bash
 $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
-[--width] [--depth] [--embed-rows] [--loss_func] [--dropout] [--batch-size] [--max-length] [--min-length]
-[--seed] [--n-iter] [--use-vectors] [--n-save_every] [--init-tok2vec] [--epoch-start]
+[--width] [--depth] [--cnn-window] [--cnn-pieces] [--use-chars] [--sa-depth]
+[--embed-rows] [--loss_func] [--dropout] [--batch-size] [--max-length]
+[--min-length] [--seed] [--n-iter] [--use-vectors] [--n-save_every]
+[--init-tok2vec] [--epoch-start]
 ```
 
 | Argument | Type | Description |
@@ -457,6 +459,10 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
 | `output_dir` | positional | Directory to write models to on each epoch. |
 | `--width`, `-cw` | option | Width of CNN layers. |
 | `--depth`, `-cd` | option | Depth of CNN layers. |
+| `--cnn-window`, `-cW` <Tag variant="new">2.2.2</Tag> | option | Window size for CNN layers. |
+| `--cnn-pieces`, `-cP` <Tag variant="new">2.2.2</Tag> | option | Maxout size for CNN layers. `1` for [Mish](https://github.com/digantamisra98/Mish). |
+| `--use-chars`, `-chr` <Tag variant="new">2.2.2</Tag> | flag | Whether to use character-based embedding. |
+| `--sa-depth`, `-sa` <Tag variant="new">2.2.2</Tag> | option | Depth of self-attention layers. |
 | `--embed-rows`, `-er` | option | Number of embedding rows. |
 | `--loss-func`, `-L` | option | Loss function to use for the objective. Either `"L2"` or `"cosine"`. |
 | `--dropout`, `-d` | option | Dropout rate. |
@@ -469,7 +475,7 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
 | `--n-save-every`, `-se` | option | Save model every X batches. |
 | `--init-tok2vec`, `-t2v` <Tag variant="new">2.1</Tag> | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental. |
 | `--epoch-start`, `-es` <Tag variant="new">2.1.5</Tag> | option | The epoch to start counting at. Only relevant when using `--init-tok2vec` and the given weight file has been renamed. Prevents unintended overwriting of existing weight files. |
 | **CREATES** | weights | The pretrained weights that can be used to initialize `spacy train`. |
 
 ### JSONL format for raw text {#pretrain-jsonl}
````
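Putting the new flags together, a pretraining run that swaps in Mish might look like this (the corpus path and output directory are placeholders):

```bash
# --cnn-pieces 1 switches the window encoder to Mish;
# --cnn-window 2 doubles the convolution window.
python -m spacy pretrain texts.jsonl en_vectors_web_lg ./pretrain_output \
    --cnn-pieces 1 --cnn-window 2
```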