mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Ragged tok2vec
This commit is contained in:
parent
837d241b68
commit
3b2654db8f
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, List, cast
|
||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import Model, chain, Linear, zero_init, use_ops
|
||||
from thinc.types import Floats2d, Ragged
|
||||
|
||||
from ...errors import Errors
|
||||
from ...compat import Literal
|
||||
|
@ -12,7 +12,7 @@ from ...tokens import Doc
|
|||
|
||||
@registry.architectures("spacy.TransitionBasedParser.v2")
|
||||
def build_tb_parser_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
tok2vec: Model[List[Doc], Ragged],
|
||||
state_type: Literal["parser", "ner"],
|
||||
extra_state_tokens: bool,
|
||||
hidden_width: int,
|
||||
|
@ -72,7 +72,7 @@ def build_tb_parser_model(
|
|||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||
tok2vec = chain(
|
||||
tok2vec,
|
||||
cast(Model[List["Floats2d"], Floats2d], list2array()),
|
||||
ragged2array(),
|
||||
Linear(hidden_width, t2v_width),
|
||||
)
|
||||
tok2vec.set_dim("nO", hidden_width)
|
||||
|
@ -90,6 +90,18 @@ def build_tb_parser_model(
|
|||
return TransitionModel(tok2vec, lower, upper, resize_output)
|
||||
|
||||
|
||||
def ragged2array() -> Model[Ragged, Floats2d]:
|
||||
def _forward(model, X, is_train):
|
||||
lengths = X.lengths
|
||||
|
||||
def backprop(dY):
|
||||
return Ragged(dY, lengths)
|
||||
|
||||
return X.dataXd, backprop
|
||||
|
||||
return Model("ragged2array", _forward)
|
||||
|
||||
|
||||
def _define_upper(nO, nI):
|
||||
return Linear(nO=nO, nI=nI, init_W=zero_init)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, List
|
||||
from thinc.api import zero_init, with_array, Softmax, chain, Model
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import zero_init, with_array, Softmax, chain, Model, ragged2list
|
||||
from thinc.types import Floats2d, Ragged
|
||||
|
||||
from ...util import registry
|
||||
from ...tokens import Doc
|
||||
|
@ -8,7 +8,7 @@ from ...tokens import Doc
|
|||
|
||||
@registry.architectures("spacy.Tagger.v1")
|
||||
def build_tagger_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
|
||||
tok2vec: Model[List[Doc], Ragged], nO: Optional[int] = None
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Build a tagger model, using a provided token-to-vector component. The tagger
|
||||
model simply adds a linear layer with softmax activation to predict scores
|
||||
|
@ -21,7 +21,7 @@ def build_tagger_model(
|
|||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
|
||||
softmax = with_array(output_layer) # type: ignore
|
||||
model = chain(tok2vec, softmax)
|
||||
model = chain(tok2vec, softmax, ragged2list())
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
model.set_ref("softmax", output_layer)
|
||||
model.set_ref("output_layer", output_layer)
|
||||
|
|
|
@ -3,6 +3,7 @@ from thinc.types import Floats2d, Ints2d, Ragged
|
|||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
|
||||
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
|
||||
from thinc.api import with_list
|
||||
|
||||
from ...tokens import Doc
|
||||
from ...util import registry
|
||||
|
@ -87,6 +88,159 @@ def build_hash_embed_cnn_tok2vec(
|
|||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.HashEmbedCNN.v3")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
depth: int,
|
||||
embed_size: int,
|
||||
window_size: int,
|
||||
maxout_pieces: int,
|
||||
subword_features: bool,
|
||||
pretrained_vectors: Optional[bool],
|
||||
) -> Model[List[Doc], Ragged]:
|
||||
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
|
||||
with subword features and a CNN with layer-normalized maxout.
|
||||
|
||||
width (int): The width of the input and output. These are required to be the
|
||||
same, so that residual connections can be used. Recommended values are
|
||||
96, 128 or 300.
|
||||
depth (int): The number of convolutional layers to use. Recommended values
|
||||
are between 2 and 8.
|
||||
window_size (int): The number of tokens on either side to concatenate during
|
||||
the convolutions. The receptive field of the CNN will be
|
||||
depth * (window_size * 2 + 1), so a 4-layer network with window_size of
|
||||
2 will be sensitive to 20 words at a time. Recommended value is 1.
|
||||
embed_size (int): The number of rows in the hash embedding tables. This can
|
||||
be surprisingly small, due to the use of the hash embeddings. Recommended
|
||||
values are between 2000 and 10000.
|
||||
maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
|
||||
If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
|
||||
subword_features (bool): Whether to also embed subword features, specifically
|
||||
the prefix, suffix and word shape. This is recommended for alphabetic
|
||||
languages like English, but not if single-character tokens are used for
|
||||
a language such as Chinese.
|
||||
pretrained_vectors (bool): Whether to also use static vectors.
|
||||
"""
|
||||
if subword_features:
|
||||
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
|
||||
else:
|
||||
attrs = ["NORM"]
|
||||
row_sizes = [embed_size]
|
||||
return build_Tok2Vec_model_ragged(
|
||||
embed=MultiHashEmbed_ragged(
|
||||
width=width,
|
||||
rows=row_sizes,
|
||||
attrs=attrs,
|
||||
include_static_vectors=bool(pretrained_vectors),
|
||||
),
|
||||
encode=MaxoutWindowEncoder(
|
||||
width=width,
|
||||
depth=depth,
|
||||
window_size=window_size,
|
||||
maxout_pieces=maxout_pieces,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2Vec.v3")
|
||||
def build_Tok2Vec_model_ragged(
|
||||
embed: Model[List[Doc], Ragged],
|
||||
encode: Model[List[Floats2d], Ragged],
|
||||
) -> Model[List[Doc], Ragged]:
|
||||
"""Construct a tok2vec model out of embedding and encoding subnetworks.
|
||||
See https://explosion.ai/blog/deep-learning-formula-nlp
|
||||
|
||||
embed (Model[List[Doc], List[Floats2d]]): Embed tokens into context-independent
|
||||
word vector representations.
|
||||
encode (Model[List[Floats2d], List[Floats2d]]): Encode context into the
|
||||
embeddings, using an architecture such as a CNN, BiLSTM or transformer.
|
||||
"""
|
||||
tok2vec = chain(embed, with_array(encode))
|
||||
if encode.has_dim("nO"):
|
||||
tok2vec.set_dim("nO", encode.get_dim("nO"))
|
||||
tok2vec.set_ref("embed", embed)
|
||||
tok2vec.set_ref("encode", encode)
|
||||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures("spacy.MultiHashEmbed.v3")
|
||||
def MultiHashEmbed_ragged(
|
||||
width: int,
|
||||
attrs: List[Union[str, int]],
|
||||
rows: List[int],
|
||||
include_static_vectors: bool,
|
||||
) -> Model[List[Doc], Ragged]:
|
||||
"""Construct an embedding layer that separately embeds a number of lexical
|
||||
attributes using hash embedding, concatenates the results, and passes it
|
||||
through a feed-forward subnetwork to build a mixed representation.
|
||||
|
||||
The features used can be configured with the 'attrs' argument. The suggested
|
||||
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
|
||||
account some subword information, without constructing a fully character-based
|
||||
representation. If pretrained vectors are available, they can be included in
|
||||
the representation as well, with the vectors table will be kept static
|
||||
(i.e. it's not updated).
|
||||
|
||||
The `width` parameter specifies the output width of the layer and the widths
|
||||
of all embedding tables. If static vectors are included, a learned linear
|
||||
layer is used to map the vectors to the specified width before concatenating
|
||||
it with the other embedding outputs. A single Maxout layer is then used to
|
||||
reduce the concatenated vectors to the final width.
|
||||
|
||||
The `rows` parameter controls the number of rows used by the `HashEmbed`
|
||||
tables. The HashEmbed layer needs surprisingly few rows, due to its use of
|
||||
the hashing trick. Generally between 2000 and 10000 rows is sufficient,
|
||||
even for very large vocabularies. A number of rows must be specified for each
|
||||
table, so the `rows` list must be of the same length as the `attrs` parameter.
|
||||
|
||||
width (int): The output width. Also used as the width of the embedding tables.
|
||||
Recommended values are between 64 and 300.
|
||||
attrs (list of attr IDs): The token attributes to embed. A separate
|
||||
embedding table will be constructed for each attribute.
|
||||
rows (List[int]): The number of rows in the embedding tables. Must have the
|
||||
same length as attrs.
|
||||
include_static_vectors (bool): Whether to also use static word vectors.
|
||||
Requires a vectors table to be loaded in the Doc objects' vocab.
|
||||
"""
|
||||
if len(rows) != len(attrs):
|
||||
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
|
||||
seed = 7
|
||||
|
||||
def make_hash_embed(index):
|
||||
nonlocal seed
|
||||
seed += 1
|
||||
return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
|
||||
|
||||
embeddings = [make_hash_embed(i) for i in range(len(attrs))]
|
||||
concat_size = width * (len(embeddings) + include_static_vectors)
|
||||
max_out: Model[Ragged, Ragged] = with_array(
|
||||
Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True) # type: ignore
|
||||
)
|
||||
if include_static_vectors:
|
||||
feature_extractor: Model[List[Doc], Ragged] = chain(
|
||||
FeatureExtractor(attrs),
|
||||
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
||||
with_array(concatenate(*embeddings)),
|
||||
)
|
||||
model = chain(
|
||||
concatenate(
|
||||
feature_extractor,
|
||||
StaticVectors(width, dropout=0.0),
|
||||
),
|
||||
max_out,
|
||||
)
|
||||
else:
|
||||
model = chain(
|
||||
FeatureExtractor(list(attrs)),
|
||||
cast(Model[List[Ints2d], Ragged], list2ragged()),
|
||||
with_array(concatenate(*embeddings)),
|
||||
max_out,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@registry.architectures("spacy.Tok2Vec.v2")
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
|
@ -295,6 +449,38 @@ def MaxoutWindowEncoder(
|
|||
return with_array(model, pad=receptive_field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@registry.architectures("spacy.MaxoutWindowEncoder.v3")
|
||||
def MaxoutWindowEncoder_ragged(
|
||||
width: int, window_size: int, maxout_pieces: int, depth: int
|
||||
) -> Model[Ragged, Ragged]:
|
||||
"""Encode context using convolutions with maxout activation, layer
|
||||
normalization and residual connections.
|
||||
|
||||
width (int): The input and output width. These are required to be the same,
|
||||
to allow residual connections. This value will be determined by the
|
||||
width of the inputs. Recommended values are between 64 and 300.
|
||||
window_size (int): The number of words to concatenate around each token
|
||||
to construct the convolution. Recommended value is 1.
|
||||
maxout_pieces (int): The number of maxout pieces to use. Recommended
|
||||
values are 2 or 3.
|
||||
depth (int): The number of convolutional layers. Recommended value is 4.
|
||||
"""
|
||||
cnn = chain(
|
||||
expand_window(window_size=window_size),
|
||||
Maxout(
|
||||
nO=width,
|
||||
nI=width * ((window_size * 2) + 1),
|
||||
nP=maxout_pieces,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
),
|
||||
)
|
||||
model = clone(residual(cnn), depth) # type: ignore[arg-type]
|
||||
model.set_dim("nO", width)
|
||||
receptive_field = window_size * depth
|
||||
return with_array(model, pad=receptive_field) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@registry.architectures("spacy.MishWindowEncoder.v2")
|
||||
def MishWindowEncoder(
|
||||
width: int, window_size: int, depth: int
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
|
||||
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
||||
from thinc.types import Ragged
|
||||
from itertools import islice
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
|
@ -132,7 +133,8 @@ class Tok2Vec(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tok2vec#set_annotations
|
||||
"""
|
||||
for doc, tokvecs in zip(docs, tokvecses):
|
||||
for i, doc in enumerate(docs):
|
||||
tokvecs = tokvecses[i].dataXd
|
||||
assert tokvecs.shape[0] == len(doc)
|
||||
doc.tensor = tokvecs
|
||||
|
||||
|
@ -162,7 +164,9 @@ class Tok2Vec(TrainablePipe):
|
|||
docs = [eg.predicted for eg in examples]
|
||||
set_dropout_rate(self.model, drop)
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
d_tokvecs = Ragged(
|
||||
self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
|
||||
)
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
def accumulate_gradient(one_d_tokvecs):
|
||||
|
@ -170,10 +174,11 @@ class Tok2Vec(TrainablePipe):
|
|||
to all but the last listener. Only the last one does the backprop.
|
||||
"""
|
||||
nonlocal d_tokvecs
|
||||
for i in range(len(one_d_tokvecs)):
|
||||
d_tokvecs[i] += one_d_tokvecs[i]
|
||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
d_tokvecs.data += one_d_tokvecs.data
|
||||
losses[self.name] += float((one_d_tokvecs.data ** 2).sum())
|
||||
return Ragged(
|
||||
self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
|
||||
)
|
||||
|
||||
def backprop(one_d_tokvecs):
|
||||
"""Callback to actually do the backprop. Passed to last listener."""
|
||||
|
@ -302,7 +307,8 @@ def forward(model: Tok2VecListener, inputs, is_train: bool):
|
|||
outputs.append(model.ops.alloc2f(len(doc), width))
|
||||
else:
|
||||
outputs.append(doc.tensor)
|
||||
return outputs, lambda dX: []
|
||||
lengths = model.ops.asarray1i([x.shape[0] for x in outputs])
|
||||
return Ragged(model.ops.flatten(outputs), lengths), lambda dX: []
|
||||
|
||||
|
||||
def _empty_backprop(dX): # for pickling
|
||||
|
|
Loading…
Reference in New Issue
Block a user