Pass option for pre-trained vectors in parser

This commit is contained in:
Matthew Honnibal 2017-09-16 12:47:21 -05:00
parent 8665a77f48
commit 5ff2491f24

View File

@ -245,7 +245,7 @@ cdef class Parser:
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
embed_size = util.env_opt('embed_size', 4000)
tensors = fine_tune(Tok2Vec(token_vector_width, embed_size,
preprocess=doc2feats()))
pretrained_dims=cfg.get('pretrained_dims')))
if parser_maxout_pieces == 1:
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
@ -391,9 +391,10 @@ cdef class Parser:
if isinstance(tokvecses, np.ndarray):
tokvecses = [tokvecses]
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
else:
tokvecs = self.model[0].ops.flatten(tokvecses)
nr_state = len(docs)
nr_class = self.moves.n_moves
@ -451,9 +452,10 @@ cdef class Parser:
cdef Doc doc
cdef int nr_class = self.moves.n_moves
cdef StateClass stcls, output
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
else:
tokvecs = self.model[0].ops.flatten(tokvecses)
cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
cuda_stream, 0.0)
@ -533,6 +535,8 @@ cdef class Parser:
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs = self.model[0].ops.flatten(my_tokvecs)
else:
tokvecs = self.model[0].ops.flatten(docs_tokvecs[1])
cuda_stream = get_cuda_stream()
@ -603,11 +607,11 @@ cdef class Parser:
docs, tokvecs = docs_tokvecs
lengths = [len(d) for d in docs]
assert min(lengths) >= 1
tokvecs = self.model[0].ops.flatten(tokvecs)
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs += self.model[0].ops.flatten(my_tokvecs)
else:
tokvecs = self.model[0].ops.flatten(tokvecs)
states = self.moves.init_batch(docs)
for gold in golds:
self.moves.preprocess_gold(gold)
@ -775,6 +779,7 @@ cdef class Parser:
for label in labels:
self.moves.add_action(action, label)
if self.model is True:
cfg['pretrained_dims'] = self.vocab.vectors_length
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
self.cfg.update(cfg)
@ -856,9 +861,11 @@ cdef class Parser:
msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'model' not in exclude:
if self.model is True:
self.model, cfg = self.Model(self.moves.n_moves)
self.model, cfg = self.Model(self.moves.n_moves,
pretrained_dims=self.vocab.vectors_length)
else:
cfg = {}
cfg['pretrained_dims'] = self.vocab.vectors_length
if 'tok2vec_model' in msg:
self.model[0].from_bytes(msg['tok2vec_model'])
if 'lower_model' in msg: