Refactor Tok2Vec to use architecture registry (#4518)

* Add refactored tok2vec, using register_architecture

* Refactor Tok2Vec

* Fix ml

* Fix new tok2vec

* Move make_layer to util

* Add wire

* Fix missing import
This commit is contained in:
Matthew Honnibal 2019-10-25 22:28:20 +02:00 committed by GitHub
parent 99e309bb19
commit 406eb95a47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 298 additions and 76 deletions

View File

@ -3,16 +3,14 @@ from __future__ import unicode_literals
import numpy
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow, ParametricAttention
from thinc.t2v import Pooling, sum_pool, mean_pool
from thinc.misc import Residual
from thinc.i2v import HashEmbed
from thinc.misc import Residual, FeatureExtracter
from thinc.misc import LayerNorm as LN
from thinc.misc import FeatureExtracter
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.api import with_getitem, flatten_add_lengths
from thinc.api import uniqued, wrap, noop
from thinc.api import with_square_sequences
from thinc.linear.linear import LinearModel
from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module, copy_array
@ -26,12 +24,8 @@ import thinc.extra.load_nlp
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
from .errors import Errors, user_warning, Warnings
from . import util
from . import ml as new_ml
try:
import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN
except ImportError:
torch = None
VECTORS_KEY = "spacy_pretrained_vectors"
@ -310,6 +304,9 @@ def link_vectors_to_models(vocab):
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
import torch.nn
from thinc.api import with_square_sequences
from thinc.extra.wrappers import PyTorchWrapperRNN
if depth == 0:
return layerize(noop())
model = torch.nn.LSTM(nI, nO // 2, depth, bidirectional=True, dropout=dropout)
@ -321,77 +318,91 @@ def Tok2Vec(width, embed_size, **kwargs):
cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
subword_features = kwargs.get("subword_features", True)
char_embed = kwargs.get("char_embed", False)
if char_embed:
subword_features = False
conv_depth = kwargs.get("conv_depth", 4)
bilstm_depth = kwargs.get("bilstm_depth", 0)
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators(
{">>": chain, "|": concatenate, "**": clone, "+": add, "*": reapply}
):
norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm")
if subword_features:
prefix = HashEmbed(
width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix"
)
suffix = HashEmbed(
width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix"
)
shape = HashEmbed(
width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape"
)
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))
if subword_features:
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 5, pieces=3)),
column=cols.index(ORTH),
)
else:
embed = uniqued(
(glove | norm) >> LN(Maxout(width, width * 2, pieces=3)),
column=cols.index(ORTH),
)
elif subword_features:
embed = uniqued(
(norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 4, pieces=3)),
column=cols.index(ORTH),
)
elif char_embed:
embed = concatenate_lists(
CharacterEmbed(nM=64, nC=8),
FeatureExtracter(cols) >> with_flatten(norm),
)
reduce_dimensions = LN(
Maxout(width, 64 * 8 + width, pieces=cnn_maxout_pieces)
)
else:
embed = norm
cols = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
doc2feats_cfg = {"arch": "spacy.Doc2Feats.v1", "config": {"columns": cols}}
if char_embed:
embed_cfg = {
"arch": "spacy.CharacterEmbed.v1",
"config": {
"width": 64,
"chars": 6,
"@mix": {
"arch": "spacy.LayerNormalizedMaxout.v1",
"config": {
"width": width,
"pieces": 3
}
},
"@embed_features": None
}
}
else:
embed_cfg = {
"arch": "spacy.MultiHashEmbed.v1",
"config": {
"width": width,
"rows": embed_size,
"columns": cols,
"use_subwords": subword_features,
"@pretrained_vectors": None,
"@mix": {
"arch": "spacy.LayerNormalizedMaxout.v1",
"config": {
"width": width,
"pieces": 3
}
},
}
}
if pretrained_vectors:
embed_cfg["config"]["@pretrained_vectors"] = {
"arch": "spacy.PretrainedVectors.v1",
"config": {
"vectors_name": pretrained_vectors,
"width": width,
"column": cols.index(ID)
}
}
cnn_cfg = {
"arch": "spacy.MaxoutWindowEncoder.v1",
"config": {
"width": width,
"window_size": 1,
"pieces": cnn_maxout_pieces,
"depth": conv_depth
}
}
convolution = Residual(
ExtractWindow(nW=1)
>> LN(Maxout(width, width * 3, pieces=cnn_maxout_pieces))
)
if char_embed:
tok2vec = embed >> with_flatten(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtracter(cols) >> with_flatten(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchBiLSTM(width, width, bilstm_depth)
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.nO = width
tok2vec.embed = embed
return tok2vec
bilstm_cfg = {
"arch": "spacy.TorchBiLSTMEncoder.v1",
"config": {
"width": width,
"depth": bilstm_depth,
}
}
if conv_depth == 0 and bilstm_depth == 0:
encode_cfg = {}
elif conv_depth >= 1 and bilstm_depth >= 1:
encode_cfg = {
"arch": "thinc.FeedForward.v1",
"config": {
"children": [cnn_cfg, bilstm_cfg]
}
}
elif conv_depth >= 1:
encode_cfg = cnn_cfg
else:
encode_cfg = bilstm_cfg
config = {
"@doc2feats": doc2feats_cfg,
"@embed": embed_cfg,
"@encode": encode_cfg
}
return new_ml.Tok2Vec(config)
def reapply(layer, n_times):

1
spacy/ml/__init__.py Normal file
View File

@ -0,0 +1 @@
from .tok2vec import Tok2Vec

42
spacy/ml/_wire.py Normal file
View File

@ -0,0 +1,42 @@
from __future__ import unicode_literals
from thinc.api import layerize, wrap, noop, chain, concatenate
from thinc.v2v import Model
def concatenate_lists(*layers, **kwargs): # pragma: no cover
"""Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
"""
if not layers:
return layerize(noop())
drop_factor = kwargs.get("drop_factor", 1.0)
ops = layers[0].ops
layers = [chain(layer, flatten) for layer in layers]
concat = concatenate(*layers)
def concatenate_lists_fwd(Xs, drop=0.0):
if drop is not None:
drop *= drop_factor
lengths = ops.asarray([len(X) for X in Xs], dtype="i")
flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
ys = ops.unflatten(flat_y, lengths)
def concatenate_lists_bwd(d_ys, sgd=None):
return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
return ys, concatenate_lists_bwd
model = wrap(concatenate_lists_fwd, concat)
return model
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=0)
X = ops.flatten(seqs, pad=0)
return X, finish_update

22
spacy/ml/common.py Normal file
View File

@ -0,0 +1,22 @@
from __future__ import unicode_literals
from thinc.api import chain
from thinc.v2v import Maxout
from thinc.misc import LayerNorm
from ..util import register_architecture, make_layer
@register_architecture("thinc.FeedForward.v1")
def FeedForward(config):
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
model = chain(*layers)
model.cfg = config
return model
@register_architecture("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(config):
width = config["width"]
pieces = config["pieces"]
layer = chain(Maxout(width, pieces=pieces), LayerNorm(nO=width))
layer.nO = width
return layer

141
spacy/ml/tok2vec.py Normal file
View File

@ -0,0 +1,141 @@
from __future__ import unicode_literals
from thinc.api import chain, layerize, clone, concatenate, with_flatten, uniqued
from thinc.api import noop, with_square_sequences
from thinc.v2v import Maxout
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual, LayerNorm, FeatureExtracter
from ..util import make_layer, register_architecture
from ._wire import concatenate_lists
from .common import *
@register_architecture("spacy.Tok2Vec.v1")
def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"])
encode = make_layer(config["@encode"])
tok2vec = chain(doc2feats, with_flatten(chain(embed, encode)))
tok2vec.cfg = config
tok2vec.nO = encode.nO
tok2vec.embed = embed
tok2vec.encode = encode
return tok2vec
@register_architecture("spacy.Doc2Feats.v1")
def Doc2Feats(config):
columns = config["columns"]
return FeatureExtracter(columns)
@register_architecture("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config):
cols = config["columns"]
width = config["width"]
rows = config["rows"]
tables = [HashEmbed(width, rows, column=cols.index("NORM"), name="embed_norm")]
if config["use_subwords"]:
for feature in ["PREFIX", "SUFFIX", "SHAPE"]:
tables.append(
HashEmbed(
width,
rows // 2,
column=cols.index(feature),
name="embed_%s" % feature.lower(),
)
)
if config.get("@pretrained_vectors"):
tables.append(make_layer(config["@pretrained_vectors"]))
mix = make_layer(config["@mix"])
# This is a pretty ugly hack. Not sure what the best solution should be.
mix._layers[0].nI = sum(table.nO for table in tables)
layer = uniqued(chain(concatenate(*tables), mix), column=cols.index("ORTH"))
layer.cfg = config
return layer
@register_architecture("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
width = config["width"]
chars = config["chars"]
chr_embed = CharacterEmbed(nM=width, nC=chars)
other_tables = make_layer(config["@embed_features"])
mix = make_layer(config["@mix"])
model = chain(concatenate_lists(chr_embed, other_tables), mix)
model.cfg = config
return model
@register_architecture("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
nP = config["pieces"]
depth = config["depth"]
cnn = chain(
ExtractWindow(nW=nW),
Maxout(nO, nO * ((nW * 2) + 1), pieces=nP),
LayerNorm(nO=nO),
)
model = clone(Residual(cnn), depth)
model.nO = nO
return model
@register_architecture("spacy.PretrainedVectors.v1")
def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"])
@register_architecture("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN
width = config["width"]
depth = config["depth"]
if depth == 0:
return layerize(noop())
return with_square_sequences(
PyTorchWrapperRNN(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
)
_EXAMPLE_CONFIG = {
"@doc2feats": {
"arch": "Doc2Feats",
"config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]},
},
"@embed": {
"arch": "spacy.MultiHashEmbed.v1",
"config": {
"width": 96,
"rows": 2000,
"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"],
"use_subwords": True,
"@pretrained_vectors": {
"arch": "TransformedStaticVectors",
"config": {
"vectors_name": "en_vectors_web_lg.vectors",
"width": 96,
"column": 0,
},
},
"@mix": {
"arch": "LayerNormalizedMaxout",
"config": {"width": 96, "pieces": 3},
},
},
},
"@encode": {
"arch": "MaxoutWindowEncode",
"config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3},
},
}

View File

@ -142,6 +142,11 @@ def register_architecture(name, arch=None):
return do_registration
def make_layer(arch_config):
arch_func = get_architecture(arch_config["arch"])
return arch_func(arch_config["config"])
def get_architecture(name):
"""Get a model architecture function by name. Raises a KeyError if the
architecture is not found.