From 123f8b832d7c00e5479e3d814bbaadb54ba54966 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Tue, 28 Jul 2020 13:51:43 +0200
Subject: [PATCH] Refactor Tok2Vec model

---
 spacy/ml/models/tok2vec.py | 365 ++++++-------------------------------
 1 file changed, 57 insertions(+), 308 deletions(-)

diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index caa9c467c..4bcd61625 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -1,8 +1,8 @@
 from typing import Optional, List
-from thinc.api import chain, clone, concatenate, with_array, uniqued
-from thinc.api import Model, noop, with_padded, Maxout, expand_window
-from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
-from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
+from thinc.api import chain, clone, concatenate, with_array, with_padded
+from thinc.api import Model, noop
+from thinc.api import FeatureExtractor, HashEmbed, StaticVectors, PyTorchLSTM
+from thinc.api import expand_window, residual, Maxout, Mish
 from thinc.types import Floats2d
 from ... import util
@@ -12,199 +12,72 @@ from ...pipeline.tok2vec import Tok2VecListener
 from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
 
 
-@registry.architectures.register("spacy.Tok2VecTensors.v1")
-def tok2vec_tensors_v1(width, upstream="*"):
+@registry.architectures.register("spacy.Tok2VecListener.v1")
+def tok2vec_listener_v1(width, upstream="*"):
     tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
     return tok2vec
 
 
-@registry.architectures.register("spacy.VocabVectors.v1")
-def get_vocab_vectors(name):
-    nlp = util.load_model(name)
-    return nlp.vocab.vectors
-
-
 @registry.architectures.register("spacy.Tok2Vec.v1")
-def Tok2Vec(extract, embed, encode):
-    field_size = 0
-    if encode.attrs.get("receptive_field", None):
-        field_size = encode.attrs["receptive_field"]
-    with Model.define_operators({">>": chain, "|": concatenate}):
-        tok2vec = extract >> with_array(embed >> encode, pad=field_size)
+def Tok2Vec(
+    embed: Model[List[Doc], List[Floats2d]],
+    encode: Model[List[Floats2d], List[Floats2d]],
+) -> Model[List[Doc], List[Floats2d]]:
+    tok2vec = with_array(
+        chain(embed, encode),
+        pad=encode.attrs.get("receptive_field", 0),
+    )
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
     tok2vec.set_ref("encode", encode)
     return tok2vec
 
 
-@registry.architectures.register("spacy.Doc2Feats.v1")
-def Doc2Feats(columns):
-    return FeatureExtractor(columns)
-
-
-@registry.architectures.register("spacy.HashEmbedCNN.v1")
-def hash_embed_cnn(
-    pretrained_vectors: str,
+@registry.architectures.register("spacy.HashEmbed.v1")
+def MultiHashEmbed(
     width: int,
-    depth: int,
-    embed_size: int,
-    maxout_pieces: int,
-    window_size: int,
-    subword_features: bool,
-    dropout: float,
-) -> Model[List[Doc], List[Floats2d]]:
-    # Does not use character embeddings: set to False by default
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        conv_depth=depth,
-        bilstm_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=window_size,
-        subword_features=subword_features,
-        char_embed=False,
-        nM=0,
-        nC=0,
-        dropout=dropout,
-    )
-
-
-@registry.architectures.register("spacy.HashCharEmbedCNN.v1")
-def hash_charembed_cnn(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    maxout_pieces,
-    window_size,
-    nM,
-    nC,
-    dropout,
+    rows: int,
+    also_embed_subwords: bool,
+    also_use_static_vectors: bool,
+    dropout: float,
 ):
-    # Allows using character embeddings by setting nC, nM and char_embed=True
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        conv_depth=depth,
-        bilstm_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=window_size,
-        subword_features=False,
-        char_embed=True,
-        nM=nM,
-        nC=nC,
-        dropout=dropout,
-    )
-
-
-@registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
-def hash_embed_bilstm_v1(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    subword_features,
-    maxout_pieces,
-    dropout,
-):
-    # Does not use character embeddings: set to False by default
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        bilstm_depth=depth,
-        conv_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=1,
-        subword_features=subword_features,
-        char_embed=False,
-        nM=0,
-        nC=0,
-        dropout=dropout,
-    )
-
-
-@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
-def hash_char_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout
-):
-    # Allows using character embeddings by setting nC, nM and char_embed=True
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        bilstm_depth=depth,
-        conv_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=1,
-        subword_features=False,
-        char_embed=True,
-        nM=nM,
-        nC=nC,
-        dropout=dropout,
-    )
-
-
-@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
-def LayerNormalizedMaxout(width, maxout_pieces):
-    return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True)
-
-
-@registry.architectures.register("spacy.MultiHashEmbed.v1")
-def MultiHashEmbed(
-    columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
-):
-    norm = HashEmbed(
-        nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
-    )
-    if use_subwords:
-        prefix = HashEmbed(
-            nO=width,
-            nV=rows // 2,
-            column=columns.index("PREFIX"),
-            dropout=dropout,
-            seed=7,
+    cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
+
+    seed = 7
+    def make_hash_embed(feature):
+        nonlocal seed
+        seed += 1
+        return HashEmbed(
+            width,
+            rows if feature == NORM else rows // 2,
+            column=cols.index(feature),
+            seed=seed
         )
-        suffix = HashEmbed(
-            nO=width,
-            nV=rows // 2,
-            column=columns.index("SUFFIX"),
-            dropout=dropout,
-            seed=8,
+
+    if also_embed_subwords:
+        embeddings = [
+            make_hash_embed(NORM),
+            make_hash_embed(PREFIX),
+            make_hash_embed(SUFFIX),
+            make_hash_embed(SHAPE),
+        ]
+    else:
+        embeddings = [make_hash_embed(NORM)]
+    concat_size = width * len(embeddings)
+
+    if also_use_static_vectors:
+        model = chain(
+            concatenate(
+                chain(FeatureExtractor(cols), concatenate(*embeddings)),
+                StaticVectors(width, dropout=dropout)
+            ),
+            Maxout(width, dropout=dropout, normalize=True)
         )
-        shape = HashEmbed(
-            nO=width,
-            nV=rows // 2,
-            column=columns.index("SHAPE"),
-            dropout=dropout,
-            seed=9,
+    else:
+        model = chain(
+            chain(FeatureExtractor(cols), concatenate(*embeddings)),
+            Maxout(width, concat_size, dropout=dropout, normalize=True)
        )
-
-    if pretrained_vectors:
-        glove = StaticVectors(
-            vectors_name=pretrained_vectors,
-            nO=width,
-            column=columns.index(ID),
-            dropout=dropout,
-        )
-
-    with Model.define_operators({">>": chain, "|": concatenate}):
-        if not use_subwords and not pretrained_vectors:
-            embed_layer = norm
-        else:
-            if use_subwords and pretrained_vectors:
-                concat_columns = glove | norm | prefix | suffix | shape
-            elif use_subwords:
-                concat_columns = norm | prefix | suffix | shape
-            else:
-                concat_columns = glove | norm
-
-            embed_layer = uniqued(
-                concat_columns >> mix, column=columns.index("ORTH")
-            )
-
-    return embed_layer
-
+    return model
+
 
 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
@@ -219,7 +92,7 @@ def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
 
 
 @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
-def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
+def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int):
     cnn = chain(
         expand_window(window_size=window_size),
         Maxout(
@@ -249,133 +122,9 @@ def MishWindowEncoder(width, window_size, depth):
 
 
 @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
-def TorchBiLSTMEncoder(width, depth):
-    import torch.nn
-
-    # TODO FIX
-    from thinc.api import PyTorchRNNWrapper
-
+def BiLSTMEncoder(width, depth, dropout):
     if depth == 0:
         return noop()
     return with_padded(
-        PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
+        PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)
     )
-
-
-def build_Tok2Vec_model(
-    width: int,
-    embed_size: int,
-    pretrained_vectors: Optional[str],
-    window_size: int,
-    maxout_pieces: int,
-    subword_features: bool,
-    char_embed: bool,
-    nM: int,
-    nC: int,
-    conv_depth: int,
-    bilstm_depth: int,
-    dropout: float,
-) -> Model:
-    if char_embed:
-        subword_features = False
-    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
-    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        norm = HashEmbed(
-            nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
-        )
-        if subword_features:
-            prefix = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(PREFIX),
-                dropout=None,
-                seed=1,
-            )
-            suffix = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(SUFFIX),
-                dropout=None,
-                seed=2,
-            )
-            shape = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(SHAPE),
-                dropout=None,
-                seed=3,
-            )
-        else:
-            prefix, suffix, shape = (None, None, None)
-        if pretrained_vectors is not None:
-            glove = StaticVectors(
-                vectors=pretrained_vectors.data,
-                nO=width,
-                column=cols.index(ID),
-                dropout=dropout,
-            )
-
-            if subword_features:
-                columns = 5
-                embed = uniqued(
-                    (glove | norm | prefix | suffix | shape)
-                    >> Maxout(
-                        nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                    ),
-                    column=cols.index(ORTH),
-                )
-            else:
-                columns = 2
-                embed = uniqued(
-                    (glove | norm)
-                    >> Maxout(
-                        nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                    ),
-                    column=cols.index(ORTH),
-                )
-        elif subword_features:
-            columns = 4
-            embed = uniqued(
-                concatenate(norm, prefix, suffix, shape)
-                >> Maxout(
-                    nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                ),
-                column=cols.index(ORTH),
-            )
-        elif char_embed:
-            embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor(
-                cols
-            ) >> with_array(norm)
-            reduce_dimensions = Maxout(
-                nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
-            )
-        else:
-            embed = norm
-
-        convolution = residual(
-            expand_window(window_size=window_size)
-            >> Maxout(
-                nO=width,
-                nI=width * ((window_size * 2) + 1),
-                nP=maxout_pieces,
-                dropout=0.0,
-                normalize=True,
-            )
-        )
-        if char_embed:
-            tok2vec = embed >> with_array(
-                reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
-            )
-        else:
-            tok2vec = FeatureExtractor(cols) >> with_array(
-                embed >> convolution ** conv_depth, pad=conv_depth
-            )
-
-        if bilstm_depth >= 1:
-            tok2vec = tok2vec >> PyTorchLSTM(
-                nO=width,
-                nI=width,
-                depth=bilstm_depth,
-                bi=True,
-            )
-    if tok2vec.has_dim("nO") is not False:
-        tok2vec.set_dim("nO", width)
-    tok2vec.set_ref("embed", embed)
-    return tok2vec
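
The new multi-feature hash embedding above chains a FeatureExtractor with one HashEmbed table per lexical attribute (NORM, PREFIX, SUFFIX, SHAPE) and mixes the concatenated tables with a normalized Maxout layer. The standalone sketch below is not part of the patch: it shows the same composition written directly against thinc, with an illustrative feature matrix, width, and table sizes.

    import numpy
    from thinc.api import HashEmbed, Maxout, chain, concatenate

    # Example feature matrix: one row per token, one column per lexical
    # attribute (e.g. column 0 = NORM, column 1 = PREFIX), as a
    # FeatureExtractor would produce.
    X = numpy.random.randint(0, 10000, size=(7, 2)).astype("uint64")

    embed = chain(
        concatenate(
            HashEmbed(64, 5000, column=0, seed=1),  # full-size table for NORM
            HashEmbed(64, 2500, column=1, seed=2),  # smaller table for the subword feature
        ),
        Maxout(64, 128, normalize=True),  # mix the concatenated embeddings back to width 64
    )
    embed.initialize(X=X)
    Y = embed.predict(X)  # shape (7, 64): one mixed embedding per token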
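
The MaxoutWindowEncoder kept above, like the convolution block deleted from build_Tok2Vec_model, is built from the expand_window/Maxout pattern wrapped in residual and clone, and the refactored Tok2Vec reads the encoder's receptive_field attribute to decide how much padding with_array needs at array boundaries. A thinc-only sketch of that pattern, with made-up hyperparameters:

    import numpy
    from thinc.api import Maxout, chain, clone, expand_window, residual, with_array

    width, window_size, maxout_pieces, depth = 64, 1, 3, 2

    cnn = chain(
        expand_window(window_size=window_size),  # concatenate each row with its neighbours
        Maxout(
            nO=width,
            nI=width * ((window_size * 2) + 1),  # centre token plus the context window
            nP=maxout_pieces,
            dropout=0.0,
            normalize=True,
        ),
    )
    encoder = clone(residual(cnn), depth)  # depth stacked residual CNN blocks
    encoder.attrs["receptive_field"] = window_size * depth  # read by Tok2Vec for padding

    model = with_array(encoder)  # map the array-level encoder over a batch of arrays
    X = [numpy.zeros((7, width), dtype="f"), numpy.zeros((3, width), dtype="f")]
    model.initialize(X=X)
    Y = model.predict(X)  # same lengths as the inputs, width columns each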