Allow tagger models to be built with pre-defined tok2vec layer

Matthew Honnibal 2017-09-26 05:51:52 -05:00
parent bf917225ab
commit e34e70673f


@@ -512,8 +512,11 @@ def build_tagger_model(nr_class, **cfg):
     token_vector_width = util.env_opt('token_vector_width', 128)
     pretrained_dims = cfg.get('pretrained_dims', 0)
     with Model.define_operators({'>>': chain, '+': add}):
-        tok2vec = Tok2Vec(token_vector_width, embed_size,
-                          pretrained_dims=pretrained_dims)
+        if 'tok2vec' in cfg:
+            tok2vec = cfg['tok2vec']
+        else:
+            tok2vec = Tok2Vec(token_vector_width, embed_size,
+                              pretrained_dims=pretrained_dims)
         model = (
             tok2vec
             >> with_flatten(Softmax(nr_class, token_vector_width))
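
With this change, a caller can construct a tok2vec layer once and hand it to the tagger through cfg, so the tagger reuses that layer instead of building its own. A minimal usage sketch, assuming the spaCy 2.0-alpha internals shown in this diff (the spacy._ml import path, the embed_size value 7000, and nr_class=17 are illustrative assumptions and may differ in other releases):

    from spacy._ml import Tok2Vec, build_tagger_model

    # Build a single tok2vec layer up front. The width 128 mirrors the
    # token_vector_width default in the hunk above; 7000 is an assumed
    # embed_size for illustration only.
    shared_tok2vec = Tok2Vec(128, 7000, pretrained_dims=0)

    # Passing the layer via cfg triggers the new branch above, so
    # build_tagger_model wraps this pre-defined tok2vec instead of
    # constructing a fresh Tok2Vec.
    tagger_model = build_tagger_model(17, tok2vec=shared_tok2vec)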