Tok2Vec: extract-embed-encode (#5102)

* avoid changing original config

* fix elif structure; batching with just an int crashes otherwise

* tok2vec example with doc2feats, encode and embed architectures

* further clean up MultiHashEmbed

* further generalize Tok2Vec to work with extract-embed-encode parts

* avoid initializing the charembed layer with Docs (for now?)

* small fixes for bilstm config (still does not run)

* rename to core layer

* move new configs

* walk model to set nI instead of using core ref

* fix senter overfitting test to be more similar to the training data (avoid flaky behaviour)
Sofie Van Landeghem 2020-03-08 13:23:18 +01:00 committed by GitHub
parent c95ce96c44
commit 5847be6022
10 changed files with 227 additions and 141 deletions
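
As a rough illustration of the new layout (not part of the commit itself), the sketch below resolves an extract-embed-encode stack through the architecture registry, using the same registry.make_from_config call that Language.create_pipe makes in the diff further down. The values mirror the second new config file below, with the ${...} references written out by hand; it assumes spaCy at this commit together with its thinc 8.x alpha dependency.

import spacy  # noqa: F401 -- importing spacy should register the built-in architectures
from spacy.util import registry

COLS = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
model_cfg = {
    "@architectures": "spacy.Tok2Vec.v1",
    "extract": {"@architectures": "spacy.Doc2Feats.v1", "columns": COLS},
    "embed": {
        "@architectures": "spacy.MultiHashEmbed.v1",
        "columns": COLS,
        "width": 96,
        "rows": 2000,
        "use_subwords": True,
        "pretrained_vectors": None,
        "mix": {
            "@architectures": "spacy.LayerNormalizedMaxout.v1",
            "width": 96,
            "maxout_pieces": 3,
        },
    },
    "encode": {
        "@architectures": "spacy.MaxoutWindowEncoder.v1",
        "width": 96,
        "window_size": 1,
        "maxout_pieces": 3,
        "depth": 2,
    },
}
# Same call Language.create_pipe uses below: each "@architectures" block is built
# bottom-up and passed to its parent as an argument.
tok2vec = registry.make_from_config({"model": model_cfg}, validate=True)["model"]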


@@ -62,4 +62,4 @@ width = 96
 depth = 4
 embed_size = 2000
 subword_features = true
-char_embed = false
+maxout_pieces = 3


@@ -0,0 +1,65 @@
[training]
use_gpu = -1
limit = 0
dropout = 0.2
patience = 10000
eval_frequency = 200
scores = ["ents_f"]
score_weights = {"ents_f": 1}
orth_variant_level = 0.0
gold_preproc = true
max_length = 0
batch_size = 25

[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
beta1 = 0.9
beta2 = 0.999

[nlp]
lang = "en"
vectors = null

[nlp.pipeline.tok2vec]
factory = "tok2vec"

[nlp.pipeline.tok2vec.model]
@architectures = "spacy.Tok2Vec.v1"

[nlp.pipeline.tok2vec.model.extract]
@architectures = "spacy.CharacterEmbed.v1"
width = 96
nM = 64
nC = 8
rows = 2000
columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]

[nlp.pipeline.tok2vec.model.extract.features]
@architectures = "spacy.Doc2Feats.v1"
columns = ${nlp.pipeline.tok2vec.model.extract:columns}

[nlp.pipeline.tok2vec.model.embed]
@architectures = "spacy.LayerNormalizedMaxout.v1"
width = ${nlp.pipeline.tok2vec.model.extract:width}
maxout_pieces = 4

[nlp.pipeline.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${nlp.pipeline.tok2vec.model.extract:width}
window_size = 1
maxout_pieces = 2
depth = 2

[nlp.pipeline.ner]
factory = "ner"

[nlp.pipeline.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 6
hidden_width = 64
maxout_pieces = 2

[nlp.pipeline.ner.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model.extract:width}


@@ -0,0 +1,65 @@
[training]
use_gpu = -1
limit = 0
dropout = 0.2
patience = 10000
eval_frequency = 200
scores = ["ents_f"]
score_weights = {"ents_f": 1}
orth_variant_level = 0.0
gold_preproc = true
max_length = 0
batch_size = 25

[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
beta1 = 0.9
beta2 = 0.999

[nlp]
lang = "en"
vectors = null

[nlp.pipeline.tok2vec]
factory = "tok2vec"

[nlp.pipeline.tok2vec.model]
@architectures = "spacy.Tok2Vec.v1"

[nlp.pipeline.tok2vec.model.extract]
@architectures = "spacy.Doc2Feats.v1"
columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]

[nlp.pipeline.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
columns = ${nlp.pipeline.tok2vec.model.extract:columns}
width = 96
rows = 2000
use_subwords = true
pretrained_vectors = null

[nlp.pipeline.tok2vec.model.embed.mix]
@architectures = "spacy.LayerNormalizedMaxout.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}
maxout_pieces = 3

[nlp.pipeline.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}
window_size = 1
maxout_pieces = 3
depth = 2

[nlp.pipeline.ner]
factory = "ner"

[nlp.pipeline.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 6
hidden_width = 64
maxout_pieces = 2

[nlp.pipeline.ner.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}


@@ -337,13 +337,14 @@ class Language(object):
         default_config = self.defaults.get(name, None)

         # transform the model's config to an actual Model
+        factory_cfg = dict(config)
         model_cfg = None
-        if "model" in config:
-            model_cfg = config["model"]
+        if "model" in factory_cfg:
+            model_cfg = factory_cfg["model"]
             if not isinstance(model_cfg, dict):
                 warnings.warn(Warnings.W099.format(type=type(model_cfg), pipe=name))
                 model_cfg = None
-            del config["model"]
+            del factory_cfg["model"]
         if model_cfg is None and default_config is not None:
             warnings.warn(Warnings.W098.format(name=name))
             model_cfg = default_config["model"]
@@ -353,7 +354,7 @@ class Language(object):
             model = registry.make_from_config({"model": model_cfg}, validate=True)[
                 "model"
             ]
-        return factory(self, model, **config)
+        return factory(self, model, **factory_cfg)

     def add_pipe(
         self, component, name=None, before=None, after=None, first=None, last=None
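
The change above is the "avoid changing original config" item from the commit message: create_pipe now takes the factory kwargs from a copy, so deleting the "model" key no longer mutates the dict the caller passed in. A standalone sketch of the pattern, with config values invented for the demo:

config = {"model": {"@architectures": "spacy.Tok2Vec.v1"}, "some_setting": True}

factory_cfg = dict(config)              # shallow copy, as in create_pipe above
if "model" in factory_cfg:
    model_cfg = factory_cfg["model"]    # hand the model config to the registry
    del factory_cfg["model"]            # ...but only drop it from the copy

assert "model" in config                # the caller's config is untouched
assert "model" not in factory_cfg       # the factory only sees its own kwargs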


@@ -21,7 +21,7 @@ def init(model, X=None, Y=None):


 def forward(model, docs, is_train):
-    if not docs:
+    if docs is None:
         return []
     ids = []
     output = []


@@ -4,7 +4,7 @@ from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
 from thinc.api import residual, LayerNorm, FeatureExtractor, Mish

 from ... import util
-from ...util import registry, make_layer
+from ...util import registry
 from ...ml import _character_embed
 from ...pipeline.tok2vec import Tok2VecListener
 from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
@@ -23,15 +23,14 @@ def get_vocab_vectors(name):


 @registry.architectures.register("spacy.Tok2Vec.v1")
-def Tok2Vec(config):
-    doc2feats = make_layer(config["@doc2feats"])
-    embed = make_layer(config["@embed"])
-    encode = make_layer(config["@encode"])
+def Tok2Vec(extract, embed, encode):
     field_size = 0
-    if encode.has_attr("receptive_field"):
+    if encode.attrs.get("receptive_field", None):
         field_size = encode.attrs["receptive_field"]
-    tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size))
-    tok2vec.attrs["cfg"] = config
+    with Model.define_operators({">>": chain, "|": concatenate}):
+        if extract.has_dim("nO"):
+            _set_dims(embed, "nI", extract.get_dim("nO"))
+        tok2vec = extract >> with_array(embed >> encode, pad=field_size)
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
     tok2vec.set_ref("encode", encode)
@@ -39,8 +38,7 @@ def Tok2Vec(config):


 @registry.architectures.register("spacy.Doc2Feats.v1")
-def Doc2Feats(config):
-    columns = config["columns"]
+def Doc2Feats(columns):
     return FeatureExtractor(columns)
@@ -79,8 +77,8 @@ def hash_charembed_cnn(
     maxout_pieces,
     window_size,
     subword_features,
-    nM=0,
-    nC=0,
+    nM,
+    nC,
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -100,7 +98,7 @@ def hash_charembed_cnn(

 @registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
 def hash_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, subword_features
+    pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces
 ):
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
@@ -109,7 +107,7 @@ def hash_embed_bilstm_v1(
         pretrained_vectors=pretrained_vectors,
         bilstm_depth=depth,
         conv_depth=0,
-        maxout_pieces=0,
+        maxout_pieces=maxout_pieces,
         window_size=1,
         subword_features=subword_features,
         char_embed=False,
@@ -120,7 +118,7 @@ def hash_embed_bilstm_v1(

 @registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
 def hash_char_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, subword_features, nM=0, nC=0
+    pretrained_vectors, width, depth, embed_size, subword_features, nM, nC, maxout_pieces
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -129,7 +127,7 @@ def hash_char_embed_bilstm_v1(
         pretrained_vectors=pretrained_vectors,
         bilstm_depth=depth,
         conv_depth=0,
-        maxout_pieces=0,
+        maxout_pieces=maxout_pieces,
         window_size=1,
         subword_features=subword_features,
         char_embed=True,
@@ -138,104 +136,99 @@ def hash_char_embed_bilstm_v1(
     )


-@registry.architectures.register("spacy.MultiHashEmbed.v1")
-def MultiHashEmbed(config):
-    # For backwards compatibility with models before the architecture registry,
-    # we have to be careful to get exactly the same model structure. One subtle
-    # trick is that when we define concatenation with the operator, the operator
-    # is actually binary associative. So when we write (a | b | c), we're actually
-    # getting concatenate(concatenate(a, b), c). That's why the implementation
-    # is a bit ugly here.
-    cols = config["columns"]
-    width = config["width"]
-    rows = config["rows"]
-
-    norm = HashEmbed(width, rows, column=cols.index("NORM"))
-    if config["use_subwords"]:
-        prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX"))
-        suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX"))
-        shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE"))
-
-    if config.get("@pretrained_vectors"):
-        glove = make_layer(config["@pretrained_vectors"])
-    mix = make_layer(config["@mix"])
+@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
+def LayerNormalizedMaxout(width, maxout_pieces):
+    return Maxout(
+        nO=width,
+        nP=maxout_pieces,
+        dropout=0.0,
+        normalize=True,
+    )
+
+
+@registry.architectures.register("spacy.MultiHashEmbed.v1")
+def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+    if use_subwords:
+        prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX"))
+        suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX"))
+        shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE"))
+
+    if pretrained_vectors:
+        glove = StaticVectors(
+            vectors=pretrained_vectors.data,
+            nO=width,
+            column=columns.index(ID),
+            dropout=0.0,
+        )

     with Model.define_operators({">>": chain, "|": concatenate}):
-        if config["use_subwords"] and config["@pretrained_vectors"]:
-            mix._layers[0].set_dim("nI", width * 5)
-            layer = uniqued(
-                (glove | norm | prefix | suffix | shape) >> mix,
-                column=cols.index("ORTH"),
-            )
-        elif config["use_subwords"]:
-            mix._layers[0].set_dim("nI", width * 4)
-            layer = uniqued(
-                (norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH")
-            )
-        elif config["@pretrained_vectors"]:
-            mix._layers[0].set_dim("nI", width * 2)
-            layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"))
+        if not use_subwords and not pretrained_vectors:
+            embed_layer = norm
         else:
-            layer = norm
-    layer.attrs["cfg"] = config
-    return layer
+            if use_subwords and pretrained_vectors:
+                nr_columns = 5
+                concat_columns = glove | norm | prefix | suffix | shape
+            elif use_subwords:
+                nr_columns = 4
+                concat_columns = norm | prefix | suffix | shape
+            else:
+                nr_columns = 2
+                concat_columns = glove | norm
+
+            _set_dims(mix, "nI", width * nr_columns)
+            embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
+
+    return embed_layer
+
+
+def _set_dims(model, name, value):
+    # Loop through the model to set a specific dimension if its unset on any layer.
+    for node in model.walk():
+        if node.has_dim(name) is None:
+            node.set_dim(name, value)


 @registry.architectures.register("spacy.CharacterEmbed.v1")
-def CharacterEmbed(config):
-    width = config["width"]
-    chars = config["chars"]
-
-    chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars)
-    other_tables = make_layer(config["@embed_features"])
-    mix = make_layer(config["@mix"])
-
-    model = chain(concatenate(chr_embed, other_tables), mix)
-    model.attrs["cfg"] = config
-    return model
+def CharacterEmbed(columns, width, rows, nM, nC, features):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+    chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
+    with Model.define_operators({">>": chain, "|": concatenate}):
+        embed_layer = chr_embed | features >> with_array(norm)
+    embed_layer.set_dim("nO", nM * nC + width)
+    return embed_layer


 @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
-def MaxoutWindowEncoder(config):
-    nO = config["width"]
-    nW = config["window_size"]
-    nP = config["pieces"]
-    depth = config["depth"]
-
-    cnn = (
-        expand_window(window_size=nW),
-        Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True),
+def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
+    cnn = chain(
+        expand_window(window_size=window_size),
+        Maxout(nO=width, nI=width * ((window_size * 2) + 1), nP=maxout_pieces, dropout=0.0, normalize=True),
     )
     model = clone(residual(cnn), depth)
-    model.set_dim("nO", nO)
-    model.attrs["receptive_field"] = nW * depth
+    model.set_dim("nO", width)
+    model.attrs["receptive_field"] = window_size * depth
     return model


 @registry.architectures.register("spacy.MishWindowEncoder.v1")
-def MishWindowEncoder(config):
-    nO = config["width"]
-    nW = config["window_size"]
-    depth = config["depth"]
-
+def MishWindowEncoder(width, window_size, depth):
     cnn = chain(
-        expand_window(window_size=nW),
-        Mish(nO=nO, nI=nO * ((nW * 2) + 1)),
-        LayerNorm(nO),
+        expand_window(window_size=window_size),
+        Mish(nO=width, nI=width * ((window_size * 2) + 1)),
+        LayerNorm(width),
     )
     model = clone(residual(cnn), depth)
-    model.set_dim("nO", nO)
+    model.set_dim("nO", width)
     return model


 @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
-def TorchBiLSTMEncoder(config):
+def TorchBiLSTMEncoder(width, depth):
     import torch.nn

     # TODO FIX
     from thinc.api import PyTorchRNNWrapper

-    width = config["width"]
-    depth = config["depth"]
     if depth == 0:
         return noop()
     return with_padded(
@@ -243,40 +236,6 @@ def TorchBiLSTMEncoder(config):
     )


-# TODO: update
-_EXAMPLE_CONFIG = {
-    "@doc2feats": {
-        "arch": "Doc2Feats",
-        "config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]},
-    },
-    "@embed": {
-        "arch": "spacy.MultiHashEmbed.v1",
-        "config": {
-            "width": 96,
-            "rows": 2000,
-            "columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"],
-            "use_subwords": True,
-            "@pretrained_vectors": {
-                "arch": "TransformedStaticVectors",
-                "config": {
-                    "vectors_name": "en_vectors_web_lg.vectors",
-                    "width": 96,
-                    "column": 0,
-                },
-            },
-            "@mix": {
-                "arch": "LayerNormalizedMaxout",
-                "config": {"width": 96, "pieces": 3},
-            },
-        },
-    },
-    "@encode": {
-        "arch": "MaxoutWindowEncode",
-        "config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3},
-    },
-}
-
-
 def build_Tok2Vec_model(
     width,
     embed_size,


@@ -131,9 +131,10 @@ class Tok2Vec(Pipe):
         get_examples (function): Function returning example training data.
         pipeline (list): The pipeline the model is part of.
         """
-        # TODO: use examples instead ?
-        docs = [Doc(Vocab(), words=["hello"])]
-        self.model.initialize(X=docs)
+        # TODO: charembed does not play nicely with dim inference yet
+        # docs = [Doc(Vocab(), words=["hello"])]
+        # self.model.initialize(X=docs)
+        self.model.initialize()
         link_vectors_to_models(self.vocab)


@@ -36,17 +36,17 @@ def test_overfitting_IO():
     assert losses["senter"] < 0.0001

     # test the trained model
-    test_text = "I like eggs. There is ham. She likes ham."
+    test_text = "I like purple eggs. They eat ham. You like yellow eggs."
     doc = nlp(test_text)
-    gold_sent_starts = [0] * 12
+    gold_sent_starts = [0] * 14
     gold_sent_starts[0] = 1
-    gold_sent_starts[4] = 1
-    gold_sent_starts[8] = 1
-    assert gold_sent_starts == [int(t.is_sent_start) for t in doc]
+    gold_sent_starts[5] = 1
+    gold_sent_starts[9] = 1
+    assert [int(t.is_sent_start) for t in doc] == gold_sent_starts

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        assert gold_sent_starts == [int(t.is_sent_start) for t in doc2]
+        assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts


@@ -79,11 +79,6 @@ def set_lang_class(name, cls):
     registry.languages.register(name, func=cls)


-def make_layer(arch_config):
-    arch_func = registry.architectures.get(arch_config["arch"])
-    return arch_func(arch_config["config"])
-
-
 def ensure_path(path):
     """Ensure string is converted to a Path.
@@ -563,7 +558,7 @@ def minibatch_by_words(examples, size, tuples=True, count_words=len):
     """Create minibatches of a given number of words."""
     if isinstance(size, int):
         size_ = itertools.repeat(size)
-    if isinstance(size, List):
+    elif isinstance(size, List):
         size_ = iter(size)
     else:
         size_ = size
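
The elif above is the "batch with just an int crashes otherwise" fix from the commit message: with two independent if statements, an int size first sets size_ to itertools.repeat(size) and then falls into the else branch, which overwrites it with the raw int, so the later next(size_) call blows up. A minimal standalone reproduction of that control flow (plain Python, not spaCy code):

import itertools

def pick_size_iterator_buggy(size):
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    if isinstance(size, list):   # should be `elif`
        size_ = iter(size)
    else:
        size_ = size             # an int lands here and clobbers repeat(size)
    return size_

print(next(pick_size_iterator_buggy([8, 16])))  # 8 -- lists are unaffected
# next(pick_size_iterator_buggy(8)) raises TypeError: 'int' object is not an iterator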