mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Ragged tok2vec
This commit is contained in:
		
							parent
							
								
									837d241b68
								
							
						
					
					
						commit
						3b2654db8f
					
				| 
						 | 
				
			
			@ -1,6 +1,6 @@
 | 
			
		|||
from typing import Optional, List, cast
 | 
			
		||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
 | 
			
		||||
from thinc.types import Floats2d
 | 
			
		||||
from thinc.api import Model, chain, Linear, zero_init, use_ops
 | 
			
		||||
from thinc.types import Floats2d, Ragged
 | 
			
		||||
 | 
			
		||||
from ...errors import Errors
 | 
			
		||||
from ...compat import Literal
 | 
			
		||||
| 
						 | 
				
			
			@ -12,7 +12,7 @@ from ...tokens import Doc
 | 
			
		|||
 | 
			
		||||
@registry.architectures("spacy.TransitionBasedParser.v2")
 | 
			
		||||
def build_tb_parser_model(
 | 
			
		||||
    tok2vec: Model[List[Doc], List[Floats2d]],
 | 
			
		||||
    tok2vec: Model[List[Doc], Ragged],
 | 
			
		||||
    state_type: Literal["parser", "ner"],
 | 
			
		||||
    extra_state_tokens: bool,
 | 
			
		||||
    hidden_width: int,
 | 
			
		||||
| 
						 | 
				
			
			@ -72,7 +72,7 @@ def build_tb_parser_model(
 | 
			
		|||
    t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
 | 
			
		||||
    tok2vec = chain(
 | 
			
		||||
        tok2vec,
 | 
			
		||||
        cast(Model[List["Floats2d"], Floats2d], list2array()),
 | 
			
		||||
        ragged2array(),
 | 
			
		||||
        Linear(hidden_width, t2v_width),
 | 
			
		||||
    )
 | 
			
		||||
    tok2vec.set_dim("nO", hidden_width)
 | 
			
		||||
| 
						 | 
				
			
			@ -90,6 +90,18 @@ def build_tb_parser_model(
 | 
			
		|||
    return TransitionModel(tok2vec, lower, upper, resize_output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ragged2array() -> Model[Ragged, Floats2d]:
 | 
			
		||||
    def _forward(model, X, is_train):
 | 
			
		||||
        lengths = X.lengths
 | 
			
		||||
 | 
			
		||||
        def backprop(dY):
 | 
			
		||||
            return Ragged(dY, lengths)
 | 
			
		||||
 | 
			
		||||
        return X.dataXd, backprop
 | 
			
		||||
 | 
			
		||||
    return Model("ragged2array", _forward)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _define_upper(nO, nI):
 | 
			
		||||
    return Linear(nO=nO, nI=nI, init_W=zero_init)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,6 @@
 | 
			
		|||
from typing import Optional, List
 | 
			
		||||
from thinc.api import zero_init, with_array, Softmax, chain, Model
 | 
			
		||||
from thinc.types import Floats2d
 | 
			
		||||
from thinc.api import zero_init, with_array, Softmax, chain, Model, ragged2list
 | 
			
		||||
from thinc.types import Floats2d, Ragged
 | 
			
		||||
 | 
			
		||||
from ...util import registry
 | 
			
		||||
from ...tokens import Doc
 | 
			
		||||
| 
						 | 
				
			
			@ -8,7 +8,7 @@ from ...tokens import Doc
 | 
			
		|||
 | 
			
		||||
@registry.architectures("spacy.Tagger.v1")
 | 
			
		||||
def build_tagger_model(
 | 
			
		||||
    tok2vec: Model[List[Doc], List[Floats2d]], nO: Optional[int] = None
 | 
			
		||||
    tok2vec: Model[List[Doc], Ragged], nO: Optional[int] = None
 | 
			
		||||
) -> Model[List[Doc], List[Floats2d]]:
 | 
			
		||||
    """Build a tagger model, using a provided token-to-vector component. The tagger
 | 
			
		||||
    model simply adds a linear layer with softmax activation to predict scores
 | 
			
		||||
| 
						 | 
				
			
			@ -21,7 +21,7 @@ def build_tagger_model(
 | 
			
		|||
    t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
 | 
			
		||||
    output_layer = Softmax(nO, t2v_width, init_W=zero_init)
 | 
			
		||||
    softmax = with_array(output_layer)  # type: ignore
 | 
			
		||||
    model = chain(tok2vec, softmax)
 | 
			
		||||
    model = chain(tok2vec, softmax, ragged2list())
 | 
			
		||||
    model.set_ref("tok2vec", tok2vec)
 | 
			
		||||
    model.set_ref("softmax", output_layer)
 | 
			
		||||
    model.set_ref("output_layer", output_layer)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from thinc.types import Floats2d, Ints2d, Ragged
 | 
			
		|||
from thinc.api import chain, clone, concatenate, with_array, with_padded
 | 
			
		||||
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
 | 
			
		||||
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
 | 
			
		||||
from thinc.api import with_list
 | 
			
		||||
 | 
			
		||||
from ...tokens import Doc
 | 
			
		||||
from ...util import registry
 | 
			
		||||
| 
						 | 
				
			
			@ -87,6 +88,159 @@ def build_hash_embed_cnn_tok2vec(
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.HashEmbedCNN.v3")
 | 
			
		||||
def build_hash_embed_cnn_tok2vec(
 | 
			
		||||
    *,
 | 
			
		||||
    width: int,
 | 
			
		||||
    depth: int,
 | 
			
		||||
    embed_size: int,
 | 
			
		||||
    window_size: int,
 | 
			
		||||
    maxout_pieces: int,
 | 
			
		||||
    subword_features: bool,
 | 
			
		||||
    pretrained_vectors: Optional[bool],
 | 
			
		||||
) -> Model[List[Doc], Ragged]:
 | 
			
		||||
    """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
 | 
			
		||||
    with subword features and a CNN with layer-normalized maxout.
 | 
			
		||||
 | 
			
		||||
    width (int): The width of the input and output. These are required to be the
 | 
			
		||||
        same, so that residual connections can be used. Recommended values are
 | 
			
		||||
        96, 128 or 300.
 | 
			
		||||
    depth (int): The number of convolutional layers to use. Recommended values
 | 
			
		||||
        are between 2 and 8.
 | 
			
		||||
    window_size (int): The number of tokens on either side to concatenate during
 | 
			
		||||
        the convolutions. The receptive field of the CNN will be
 | 
			
		||||
        depth * (window_size * 2 + 1), so a 4-layer network with window_size of
 | 
			
		||||
        2 will be sensitive to 20 words at a time. Recommended value is 1.
 | 
			
		||||
    embed_size (int): The number of rows in the hash embedding tables. This can
 | 
			
		||||
        be surprisingly small, due to the use of the hash embeddings. Recommended
 | 
			
		||||
        values are between 2000 and 10000.
 | 
			
		||||
    maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
 | 
			
		||||
        If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
 | 
			
		||||
    subword_features (bool): Whether to also embed subword features, specifically
 | 
			
		||||
        the prefix, suffix and word shape. This is recommended for alphabetic
 | 
			
		||||
        languages like English, but not if single-character tokens are used for
 | 
			
		||||
        a language such as Chinese.
 | 
			
		||||
    pretrained_vectors (bool): Whether to also use static vectors.
 | 
			
		||||
    """
 | 
			
		||||
    if subword_features:
 | 
			
		||||
        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
 | 
			
		||||
        row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
 | 
			
		||||
    else:
 | 
			
		||||
        attrs = ["NORM"]
 | 
			
		||||
        row_sizes = [embed_size]
 | 
			
		||||
    return build_Tok2Vec_model_ragged(
 | 
			
		||||
        embed=MultiHashEmbed_ragged(
 | 
			
		||||
            width=width,
 | 
			
		||||
            rows=row_sizes,
 | 
			
		||||
            attrs=attrs,
 | 
			
		||||
            include_static_vectors=bool(pretrained_vectors),
 | 
			
		||||
        ),
 | 
			
		||||
        encode=MaxoutWindowEncoder(
 | 
			
		||||
            width=width,
 | 
			
		||||
            depth=depth,
 | 
			
		||||
            window_size=window_size,
 | 
			
		||||
            maxout_pieces=maxout_pieces,
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.Tok2Vec.v3")
 | 
			
		||||
def build_Tok2Vec_model_ragged(
 | 
			
		||||
    embed: Model[List[Doc], Ragged],
 | 
			
		||||
    encode: Model[List[Floats2d], Ragged],
 | 
			
		||||
) -> Model[List[Doc], Ragged]:
 | 
			
		||||
    """Construct a tok2vec model out of embedding and encoding subnetworks.
 | 
			
		||||
    See https://explosion.ai/blog/deep-learning-formula-nlp
 | 
			
		||||
 | 
			
		||||
    embed (Model[List[Doc], List[Floats2d]]): Embed tokens into context-independent
 | 
			
		||||
        word vector representations.
 | 
			
		||||
    encode (Model[List[Floats2d], List[Floats2d]]): Encode context into the
 | 
			
		||||
        embeddings, using an architecture such as a CNN, BiLSTM or transformer.
 | 
			
		||||
    """
 | 
			
		||||
    tok2vec = chain(embed, with_array(encode))
 | 
			
		||||
    if encode.has_dim("nO"):
 | 
			
		||||
        tok2vec.set_dim("nO", encode.get_dim("nO"))
 | 
			
		||||
    tok2vec.set_ref("embed", embed)
 | 
			
		||||
    tok2vec.set_ref("encode", encode)
 | 
			
		||||
    return tok2vec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.MultiHashEmbed.v3")
 | 
			
		||||
def MultiHashEmbed_ragged(
 | 
			
		||||
    width: int,
 | 
			
		||||
    attrs: List[Union[str, int]],
 | 
			
		||||
    rows: List[int],
 | 
			
		||||
    include_static_vectors: bool,
 | 
			
		||||
) -> Model[List[Doc], Ragged]:
 | 
			
		||||
    """Construct an embedding layer that separately embeds a number of lexical
 | 
			
		||||
    attributes using hash embedding, concatenates the results, and passes it
 | 
			
		||||
    through a feed-forward subnetwork to build a mixed representation.
 | 
			
		||||
 | 
			
		||||
    The features used can be configured with the 'attrs' argument. The suggested
 | 
			
		||||
    attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
 | 
			
		||||
    account some subword information, without constructing a fully character-based
 | 
			
		||||
    representation. If pretrained vectors are available, they can be included in
 | 
			
		||||
    the representation as well, with the vectors table will be kept static
 | 
			
		||||
    (i.e. it's not updated).
 | 
			
		||||
 | 
			
		||||
    The `width` parameter specifies the output width of the layer and the widths
 | 
			
		||||
    of all embedding tables. If static vectors are included, a learned linear
 | 
			
		||||
    layer is used to map the vectors to the specified width before concatenating
 | 
			
		||||
    it with the other embedding outputs. A single Maxout layer is then used to
 | 
			
		||||
    reduce the concatenated vectors to the final width.
 | 
			
		||||
 | 
			
		||||
    The `rows` parameter controls the number of rows used by the `HashEmbed`
 | 
			
		||||
    tables. The HashEmbed layer needs surprisingly few rows, due to its use of
 | 
			
		||||
    the hashing trick. Generally between 2000 and 10000 rows is sufficient,
 | 
			
		||||
    even for very large vocabularies. A number of rows must be specified for each
 | 
			
		||||
    table, so the `rows` list must be of the same length as the `attrs` parameter.
 | 
			
		||||
 | 
			
		||||
    width (int): The output width. Also used as the width of the embedding tables.
 | 
			
		||||
        Recommended values are between 64 and 300.
 | 
			
		||||
    attrs (list of attr IDs): The token attributes to embed. A separate
 | 
			
		||||
        embedding table will be constructed for each attribute.
 | 
			
		||||
    rows (List[int]): The number of rows in the embedding tables. Must have the
 | 
			
		||||
        same length as attrs.
 | 
			
		||||
    include_static_vectors (bool): Whether to also use static word vectors.
 | 
			
		||||
        Requires a vectors table to be loaded in the Doc objects' vocab.
 | 
			
		||||
    """
 | 
			
		||||
    if len(rows) != len(attrs):
 | 
			
		||||
        raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
 | 
			
		||||
    seed = 7
 | 
			
		||||
 | 
			
		||||
    def make_hash_embed(index):
 | 
			
		||||
        nonlocal seed
 | 
			
		||||
        seed += 1
 | 
			
		||||
        return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
 | 
			
		||||
 | 
			
		||||
    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
 | 
			
		||||
    concat_size = width * (len(embeddings) + include_static_vectors)
 | 
			
		||||
    max_out: Model[Ragged, Ragged] = with_array(
 | 
			
		||||
        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)  # type: ignore
 | 
			
		||||
    )
 | 
			
		||||
    if include_static_vectors:
 | 
			
		||||
        feature_extractor: Model[List[Doc], Ragged] = chain(
 | 
			
		||||
            FeatureExtractor(attrs),
 | 
			
		||||
            cast(Model[List[Ints2d], Ragged], list2ragged()),
 | 
			
		||||
            with_array(concatenate(*embeddings)),
 | 
			
		||||
        )
 | 
			
		||||
        model = chain(
 | 
			
		||||
            concatenate(
 | 
			
		||||
                feature_extractor,
 | 
			
		||||
                StaticVectors(width, dropout=0.0),
 | 
			
		||||
            ),
 | 
			
		||||
            max_out,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        model = chain(
 | 
			
		||||
            FeatureExtractor(list(attrs)),
 | 
			
		||||
            cast(Model[List[Ints2d], Ragged], list2ragged()),
 | 
			
		||||
            with_array(concatenate(*embeddings)),
 | 
			
		||||
            max_out,
 | 
			
		||||
        )
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.Tok2Vec.v2")
 | 
			
		||||
def build_Tok2Vec_model(
 | 
			
		||||
    embed: Model[List[Doc], List[Floats2d]],
 | 
			
		||||
| 
						 | 
				
			
			@ -295,6 +449,38 @@ def MaxoutWindowEncoder(
 | 
			
		|||
    return with_array(model, pad=receptive_field)  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.MaxoutWindowEncoder.v3")
 | 
			
		||||
def MaxoutWindowEncoder_ragged(
 | 
			
		||||
    width: int, window_size: int, maxout_pieces: int, depth: int
 | 
			
		||||
) -> Model[Ragged, Ragged]:
 | 
			
		||||
    """Encode context using convolutions with maxout activation, layer
 | 
			
		||||
    normalization and residual connections.
 | 
			
		||||
 | 
			
		||||
    width (int): The input and output width. These are required to be the same,
 | 
			
		||||
        to allow residual connections. This value will be determined by the
 | 
			
		||||
        width of the inputs. Recommended values are between 64 and 300.
 | 
			
		||||
    window_size (int): The number of words to concatenate around each token
 | 
			
		||||
        to construct the convolution. Recommended value is 1.
 | 
			
		||||
    maxout_pieces (int): The number of maxout pieces to use. Recommended
 | 
			
		||||
        values are 2 or 3.
 | 
			
		||||
    depth (int): The number of convolutional layers. Recommended value is 4.
 | 
			
		||||
    """
 | 
			
		||||
    cnn = chain(
 | 
			
		||||
        expand_window(window_size=window_size),
 | 
			
		||||
        Maxout(
 | 
			
		||||
            nO=width,
 | 
			
		||||
            nI=width * ((window_size * 2) + 1),
 | 
			
		||||
            nP=maxout_pieces,
 | 
			
		||||
            dropout=0.0,
 | 
			
		||||
            normalize=True,
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    model = clone(residual(cnn), depth)  # type: ignore[arg-type]
 | 
			
		||||
    model.set_dim("nO", width)
 | 
			
		||||
    receptive_field = window_size * depth
 | 
			
		||||
    return with_array(model, pad=receptive_field)  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.architectures("spacy.MishWindowEncoder.v2")
 | 
			
		||||
def MishWindowEncoder(
 | 
			
		||||
    width: int, window_size: int, depth: int
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
 | 
			
		||||
from thinc.api import Model, set_dropout_rate, Optimizer, Config
 | 
			
		||||
from thinc.types import Ragged
 | 
			
		||||
from itertools import islice
 | 
			
		||||
 | 
			
		||||
from .trainable_pipe import TrainablePipe
 | 
			
		||||
| 
						 | 
				
			
			@ -132,7 +133,8 @@ class Tok2Vec(TrainablePipe):
 | 
			
		|||
 | 
			
		||||
        DOCS: https://spacy.io/api/tok2vec#set_annotations
 | 
			
		||||
        """
 | 
			
		||||
        for doc, tokvecs in zip(docs, tokvecses):
 | 
			
		||||
        for i, doc in enumerate(docs):
 | 
			
		||||
            tokvecs = tokvecses[i].dataXd
 | 
			
		||||
            assert tokvecs.shape[0] == len(doc)
 | 
			
		||||
            doc.tensor = tokvecs
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +164,9 @@ class Tok2Vec(TrainablePipe):
 | 
			
		|||
        docs = [eg.predicted for eg in examples]
 | 
			
		||||
        set_dropout_rate(self.model, drop)
 | 
			
		||||
        tokvecs, bp_tokvecs = self.model.begin_update(docs)
 | 
			
		||||
        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
 | 
			
		||||
        d_tokvecs = Ragged(
 | 
			
		||||
            self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
 | 
			
		||||
        )
 | 
			
		||||
        losses.setdefault(self.name, 0.0)
 | 
			
		||||
 | 
			
		||||
        def accumulate_gradient(one_d_tokvecs):
 | 
			
		||||
| 
						 | 
				
			
			@ -170,10 +174,11 @@ class Tok2Vec(TrainablePipe):
 | 
			
		|||
            to all but the last listener. Only the last one does the backprop.
 | 
			
		||||
            """
 | 
			
		||||
            nonlocal d_tokvecs
 | 
			
		||||
            for i in range(len(one_d_tokvecs)):
 | 
			
		||||
                d_tokvecs[i] += one_d_tokvecs[i]
 | 
			
		||||
                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
 | 
			
		||||
            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
 | 
			
		||||
            d_tokvecs.data += one_d_tokvecs.data
 | 
			
		||||
            losses[self.name] += float((one_d_tokvecs.data ** 2).sum())
 | 
			
		||||
            return Ragged(
 | 
			
		||||
                self.model.ops.alloc2f(*tokvecs.dataXd.shape), tokvecs.lengths
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        def backprop(one_d_tokvecs):
 | 
			
		||||
            """Callback to actually do the backprop. Passed to last listener."""
 | 
			
		||||
| 
						 | 
				
			
			@ -302,7 +307,8 @@ def forward(model: Tok2VecListener, inputs, is_train: bool):
 | 
			
		|||
                outputs.append(model.ops.alloc2f(len(doc), width))
 | 
			
		||||
            else:
 | 
			
		||||
                outputs.append(doc.tensor)
 | 
			
		||||
        return outputs, lambda dX: []
 | 
			
		||||
        lengths = model.ops.asarray1i([x.shape[0] for x in outputs])
 | 
			
		||||
        return Ragged(model.ops.flatten(outputs), lengths), lambda dX: []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _empty_backprop(dX):  # for pickling
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user