Ragged tok2vec

This commit is contained in:
Matthew Honnibal 2021-12-24 16:39:52 +01:00
parent 837d241b68
commit 3b2654db8f
4 changed files with 219 additions and 15 deletions

View File

@ -1,6 +1,6 @@
from typing import Optional, List, cast
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d
from thinc.api import Model, chain, Linear, zero_init, use_ops
from thinc.types import Floats2d, Ragged
from ...errors import Errors
from ...compat import Literal
@ -12,7 +12,7 @@ from ...tokens import Doc
@registry.architectures("spacy.TransitionBasedParser.v2")
def build_tb_parser_model(
tok2vec: Model[List[Doc], List[Floats2d]],
tok2vec: Model[List[Doc], Ragged],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
@ -72,7 +72,7 @@ def build_tb_parser_model(
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(
tok2vec,
cast(Model[List["Floats2d"], Floats2d], list2array()),
ragged2array(),
Linear(hidden_width, t2v_width),
)
tok2vec.set_dim("nO", hidden_width)
@ -90,6 +90,18 @@ def build_tb_parser_model(
return TransitionModel(tok2vec, lower, upper, resize_output)
def ragged2array() -> Model[Ragged, Floats2d]:
def _forward(model, X, is_train):
lengths = X.lengths
def backprop(dY):
return Ragged(dY, lengths)
return X.dataXd, backprop
return Model("ragged2array", _forward)
def _define_upper(nO, nI):
return Linear(nO=nO, nI=nI, init_W=zero_init)

View File

@ -1,6 +1,6 @@
from typing import Optional, List
from thinc.api import zero_init, with_array, Softmax, chain, Model
from thinc.types import Floats2d
from thinc.api import zero_init, with_array, Softmax, chain, Model, ragged2list
from thinc.types import Floats2d, Ragged
from ...util import registry
from ...tokens import Doc
@ -8,7 +8,7 @@ from ...tokens import Doc
@registry.architectures("spacy.Tagger.v1")
def build_tagger_model(
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
tok2vec: Model[List[Doc], Ragged], nO: Optional[int] = None
) -> Model[List[Doc], List[Floats2d]]:
"""Build a tagger model, using a provided token-to-vector component. The tagger
model simply adds a linear layer with softmax activation to predict scores
@ -21,7 +21,7 @@ def build_tagger_model(
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
softmax = with_array(output_layer) # type: ignore
model = chain(tok2vec, softmax)
model = chain(tok2vec, softmax, ragged2list())
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", output_layer)
model.set_ref("output_layer", output_layer)

View File

@ -3,6 +3,7 @@ from thinc.types import Floats2d, Ints2d, Ragged
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
from thinc.api import with_list
from ...tokens import Doc
from ...util import registry
@ -87,6 +88,159 @@ def build_hash_embed_cnn_tok2vec(
)
@registry.architectures("spacy.HashEmbedCNN.v3")
def build_hash_embed_cnn_tok2vec(
*,
width: int,
depth: int,
embed_size: int,
window_size: int,
maxout_pieces: int,
subword_features: bool,
pretrained_vectors: Optional[bool],
) -> Model[List[Doc], Ragged]:
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
with subword features and a CNN with layer-normalized maxout.
width (int): The width of the input and output. These are required to be the
same, so that residual connections can be used. Recommended values are
96, 128 or 300.
depth (int): The number of convolutional layers to use. Recommended values
are between 2 and 8.
window_size (int): The number of tokens on either side to concatenate during
the convolutions. The receptive field of the CNN will be
depth * (window_size * 2 + 1), so a 4-layer network with window_size of
2 will be sensitive to 20 words at a time. Recommended value is 1.
embed_size (int): The number of rows in the hash embedding tables. This can
be surprisingly small, due to the use of the hash embeddings. Recommended
values are between 2000 and 10000.
maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
subword_features (bool): Whether to also embed subword features, specifically
the prefix, suffix and word shape. This is recommended for alphabetic
languages like English, but not if single-character tokens are used for
a language such as Chinese.
pretrained_vectors (bool): Whether to also use static vectors.
"""
if subword_features:
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
else:
attrs = ["NORM"]
row_sizes = [embed_size]
return build_Tok2Vec_model_ragged(
embed=MultiHashEmbed_ragged(
width=width,
rows=row_sizes,
attrs=attrs,
include_static_vectors=bool(pretrained_vectors),
),
encode=MaxoutWindowEncoder(
width=width,
depth=depth,
window_size=window_size,
maxout_pieces=maxout_pieces,
),
)
@registry.architectures("spacy.Tok2Vec.v3")
def build_Tok2Vec_model_ragged(
embed: Model[List[Doc], Ragged],
encode: Model[List[Floats2d], Ragged],
) -> Model[List[Doc], Ragged]:
"""Construct a tok2vec model out of embedding and encoding subnetworks.
See https://explosion.ai/blog/deep-learning-formula-nlp
embed (Model[List[Doc], List[Floats2d]]): Embed tokens into context-independent
word vector representations.
encode (Model[List[Floats2d], List[Floats2d]]): Encode context into the
embeddings, using an architecture such as a CNN, BiLSTM or transformer.
"""
tok2vec = chain(embed, with_array(encode))
if encode.has_dim("nO"):
tok2vec.set_dim("nO", encode.get_dim("nO"))
tok2vec.set_ref("embed", embed)
tok2vec.set_ref("encode", encode)
return tok2vec
@registry.architectures("spacy.MultiHashEmbed.v3")
def MultiHashEmbed_ragged(
width: int,
attrs: List[Union[str, int]],
rows: List[int],
include_static_vectors: bool,
) -> Model[List[Doc], Ragged]:
"""Construct an embedding layer that separately embeds a number of lexical
attributes using hash embedding, concatenates the results, and passes it
through a feed-forward subnetwork to build a mixed representation.
The features used can be configured with the 'attrs' argument. The suggested
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
account some subword information, without constructing a fully character-based
representation. If pretrained vectors are available, they can be included in
the representation as well, with the vectors table will be kept static
(i.e. it's not updated).
The `width` parameter specifies the output width of the layer and the widths
of all embedding tables. If static vectors are included, a learned linear
layer is used to map the vectors to the specified width before concatenating
it with the other embedding outputs. A single Maxout layer is then used to
reduce the concatenated vectors to the final width.
The `rows` parameter controls the number of rows used by the `HashEmbed`
tables. The HashEmbed layer needs surprisingly few rows, due to its use of
the hashing trick. Generally between 2000 and 10000 rows is sufficient,
even for very large vocabularies. A number of rows must be specified for each
table, so the `rows` list must be of the same length as the `attrs` parameter.
width (int): The output width. Also used as the width of the embedding tables.
Recommended values are between 64 and 300.
attrs (list of attr IDs): The token attributes to embed. A separate
embedding table will be constructed for each attribute.
rows (List[int]): The number of rows in the embedding tables. Must have the
same length as attrs.
include_static_vectors (bool): Whether to also use static word vectors.
Requires a vectors table to be loaded in the Doc objects' vocab.
"""
if len(rows) != len(attrs):
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
seed = 7
def make_hash_embed(index):
nonlocal seed
seed += 1
return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors)
max_out: Model[Ragged, Ragged] = with_array(
Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True) # type: ignore
)
if include_static_vectors:
feature_extractor: Model[List[Doc], Ragged] = chain(
FeatureExtractor(attrs),
cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
)
model = chain(
concatenate(
feature_extractor,
StaticVectors(width, dropout=0.0),
),
max_out,
)
else:
model = chain(
FeatureExtractor(list(attrs)),
cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
max_out,
)
return model
@registry.architectures("spacy.Tok2Vec.v2")
def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]],
@ -295,6 +449,38 @@ def MaxoutWindowEncoder(
return with_array(model, pad=receptive_field) # type: ignore[arg-type]
@registry.architectures("spacy.MaxoutWindowEncoder.v3")
def MaxoutWindowEncoder_ragged(
width: int, window_size: int, maxout_pieces: int, depth: int
) -> Model[Ragged, Ragged]:
"""Encode context using convolutions with maxout activation, layer
normalization and residual connections.
width (int): The input and output width. These are required to be the same,
to allow residual connections. This value will be determined by the
width of the inputs. Recommended values are between 64 and 300.
window_size (int): The number of words to concatenate around each token
to construct the convolution. Recommended value is 1.
maxout_pieces (int): The number of maxout pieces to use. Recommended
values are 2 or 3.
depth (int): The number of convolutional layers. Recommended value is 4.
"""
cnn = chain(
expand_window(window_size=window_size),
Maxout(
nO=width,
nI=width * ((window_size * 2) + 1),
nP=maxout_pieces,
dropout=0.0,
normalize=True,
),
)
model = clone(residual(cnn), depth) # type: ignore[arg-type]
model.set_dim("nO", width)
receptive_field = window_size * depth
return with_array(model, pad=receptive_field) # type: ignore[arg-type]
@registry.architectures("spacy.MishWindowEncoder.v2")
def MishWindowEncoder(
width: int, window_size: int, depth: int

View File

@ -1,5 +1,6 @@
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
from thinc.api import Model, set_dropout_rate, Optimizer, Config
from thinc.types import Ragged
from itertools import islice
from .trainable_pipe import TrainablePipe
@ -132,7 +133,8 @@ class Tok2Vec(TrainablePipe):
DOCS: https://spacy.io/api/tok2vec#set_annotations
"""
for doc, tokvecs in zip(docs, tokvecses):
for i, doc in enumerate(docs):
tokvecs = tokvecses[i].dataXd
assert tokvecs.shape[0] == len(doc)
doc.tensor = tokvecs
@ -162,7 +164,9 @@ class Tok2Vec(TrainablePipe):
docs = [eg.predicted for eg in examples]
set_dropout_rate(self.model, drop)
tokvecs, bp_tokvecs = self.model.begin_update(docs)
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
d_tokvecs = Ragged(
self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
)
losses.setdefault(self.name, 0.0)
def accumulate_gradient(one_d_tokvecs):
@ -170,10 +174,11 @@ class Tok2Vec(TrainablePipe):
to all but the last listener. Only the last one does the backprop.
"""
nonlocal d_tokvecs
for i in range(len(one_d_tokvecs)):
d_tokvecs[i] += one_d_tokvecs[i]
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
d_tokvecs.data += one_d_tokvecs.data
losses[self.name] += float((one_d_tokvecs.data ** 2).sum())
return Ragged(
self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
)
def backprop(one_d_tokvecs):
"""Callback to actually do the backprop. Passed to last listener."""
@ -302,7 +307,8 @@ def forward(model: Tok2VecListener, inputs, is_train: bool):
outputs.append(model.ops.alloc2f(len(doc), width))
else:
outputs.append(doc.tensor)
return outputs, lambda dX: []
lengths = model.ops.asarray1i([x.shape[0] for x in outputs])
return Ragged(model.ops.flatten(outputs), lengths), lambda dX: []
def _empty_backprop(dX): # for pickling