Pass parser settings better

This commit is contained in:
Matthw Honnibal 2019-10-23 04:41:20 +02:00
parent 8892ce98aa
commit 95648dcdd7

View File

@ -56,10 +56,11 @@ cdef class Parser:
subword_features = util.env_opt('subword_features', subword_features = util.env_opt('subword_features',
cfg.get('subword_features', True)) cfg.get('subword_features', True))
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4)) conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
conv_window = util.env_opt('conv_window', cfg.get('conv_depth', 1)) conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1))
t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3)) t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0)) bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0)) self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
char_embed = util.env_opt('char_embed', cfg.get('char_embed', 0))
if depth not in (0, 1): if depth not in (0, 1):
raise ValueError(TempErrors.T004.format(value=depth)) raise ValueError(TempErrors.T004.format(value=depth))
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
@ -76,6 +77,7 @@ cdef class Parser:
conv_window=conv_window, conv_window=conv_window,
cnn_maxout_pieces=t2v_pieces, cnn_maxout_pieces=t2v_pieces,
subword_features=subword_features, subword_features=subword_features,
char_embed=char_embed,
pretrained_vectors=pretrained_vectors, pretrained_vectors=pretrained_vectors,
bilstm_depth=bilstm_depth, bilstm_depth=bilstm_depth,
self_attn_depth=self_attn_depth) self_attn_depth=self_attn_depth)
@ -104,6 +106,7 @@ cdef class Parser:
'conv_depth': conv_depth, 'conv_depth': conv_depth,
'conv_window': conv_window, 'conv_window': conv_window,
'embed_size': embed_size, 'embed_size': embed_size,
'char_embed': char_embed,
'cnn_maxout_pieces': t2v_pieces 'cnn_maxout_pieces': t2v_pieces
} }
return ParserModel(tok2vec, lower, upper), cfg return ParserModel(tok2vec, lower, upper), cfg