Make it easier to reference embedding tables

This commit is contained in:
Matthew Honnibal 2017-05-29 17:53:29 -05:00
parent 293d1b425b
commit b92a89f87b

View File

@ -133,15 +133,16 @@ class PrecomputableMaxouts(Model):
def Tok2Vec(width, embed_size, preprocess=None):
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
lower = get_col(cols.index(LOWER)) >> HashEmbed(width, embed_size)
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2)
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2)
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2)
lower = get_col(cols.index(LOWER)) >> HashEmbed(width, embed_size, name='embed_lower')
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix')
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix')
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape')
embed = (lower | prefix | suffix | shape )
tok2vec = (
with_flatten(
asarray(Model.ops, dtype='uint64')
>> (lower | prefix | suffix | shape )
>> embed
>> Maxout(width, width*4, pieces=3)
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
@ -153,6 +154,7 @@ def Tok2Vec(width, embed_size, preprocess=None):
tok2vec = preprocess >> tok2vec
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.nO = width
tok2vec.embed = embed
return tok2vec