Fiddle with sizings for parser

This commit is contained in:
Matthew Honnibal 2017-05-13 17:20:23 -05:00
parent e6d71e1778
commit 613ba79e2e

View File

@ -41,23 +41,23 @@ class TokenVectorEncoder(object):
Softmax(self.vocab.morphology.n_tags,
token_vector_width))
def build_model(self, lang, width, embed_size=1000, **cfg):
def build_model(self, lang, width, embed_size=5000, **cfg):
cols = self.doc2feats.cols
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
lower = get_col(cols.index(LOWER)) >> (HashEmbed(width, embed_size*3)
+HashEmbed(width, embed_size*3))
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
lower = get_col(cols.index(LOWER)) >> (HashEmbed(width, embed_size)
+HashEmbed(width, embed_size))
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2)
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2)
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2)
tok2vec = (
flatten
>> (lower | prefix | suffix | shape )
>> BN(Maxout(width, pieces=3))
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
>> Residual(ExtractWindow(nW=1) >> BN(Maxout(width, width*3)))
>> Maxout(width, pieces=3)
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
)
return tok2vec
@ -80,7 +80,9 @@ class TokenVectorEncoder(object):
scores, finish_update = self.tagger.begin_update(feats, drop=drop)
scores, _ = self.tagger.begin_update(feats, drop=drop)
idx = 0
guesses = scores.argmax(axis=1).get()
guesses = scores.argmax(axis=1)
if not isinstance(guesses, numpy.ndarray):
guesses = guesses.get()
for i, doc in enumerate(docs):
tag_ids = guesses[idx:idx+len(doc)]
for j, tag_id in enumerate(tag_ids):