Pass config better in nn_parser

This commit is contained in:
Matthw Honnibal 2019-10-17 21:10:56 +02:00
parent e737750a02
commit ca0759b325

View File

@ -57,6 +57,7 @@ cdef class Parser:
cfg.get('subword_features', True))
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
if depth != 1:
raise ValueError(TempErrors.T004.format(value=depth))
parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
@ -70,7 +71,8 @@ cdef class Parser:
conv_depth=conv_depth,
subword_features=subword_features,
pretrained_vectors=pretrained_vectors,
bilstm_depth=bilstm_depth)
bilstm_depth=bilstm_depth,
self_attn_depth=self_attn_depth)
tok2vec = chain(tok2vec, flatten)
tok2vec.nO = token_vector_width
lower = PrecomputableAffine(hidden_width,
@ -89,7 +91,10 @@ cdef class Parser:
'hidden_width': hidden_width,
'maxout_pieces': parser_maxout_pieces,
'pretrained_vectors': pretrained_vectors,
'bilstm_depth': bilstm_depth
'bilstm_depth': bilstm_depth,
'self_attn_depth': self_attn_depth,
'conv_depth': conv_depth,
'embed_size': embed_size
}
return ParserModel(tok2vec, lower, upper), cfg