From 5ff2491f245c95e8003b84aeeef379cb8ef5676d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 16 Sep 2017 12:47:21 -0500 Subject: [PATCH] Pass option for pre-trained vectors in parser --- spacy/syntax/nn_parser.pyx | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 1c4107c06..04cf20d12 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -245,7 +245,7 @@ cdef class Parser: parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) embed_size = util.env_opt('embed_size', 4000) tensors = fine_tune(Tok2Vec(token_vector_width, embed_size, - preprocess=doc2feats())) + pretrained_dims=cfg.get('pretrained_dims'))) if parser_maxout_pieces == 1: lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class, nF=cls.nr_feature, @@ -391,9 +391,10 @@ cdef class Parser: if isinstance(tokvecses, np.ndarray): tokvecses = [tokvecses] - tokvecs = self.model[0].ops.flatten(tokvecses) if USE_FINE_TUNE: tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) + else: + tokvecs = self.model[0].ops.flatten(tokvecses) nr_state = len(docs) nr_class = self.moves.n_moves @@ -451,9 +452,10 @@ cdef class Parser: cdef Doc doc cdef int nr_class = self.moves.n_moves cdef StateClass stcls, output - tokvecs = self.model[0].ops.flatten(tokvecses) if USE_FINE_TUNE: tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) + else: + tokvecs = self.model[0].ops.flatten(tokvecses) cuda_stream = get_cuda_stream() state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0) @@ -533,6 +535,8 @@ cdef class Parser: if USE_FINE_TUNE: my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) tokvecs = self.model[0].ops.flatten(my_tokvecs) + else: + tokvecs = self.model[0].ops.flatten(docs_tokvecs[1]) cuda_stream = get_cuda_stream() @@ -603,11 +607,11 @@ cdef class Parser: docs, tokvecs = docs_tokvecs lengths = [len(d) for d in docs] assert min(lengths) >= 1 - tokvecs = self.model[0].ops.flatten(tokvecs) if USE_FINE_TUNE: my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) tokvecs += self.model[0].ops.flatten(my_tokvecs) - + else: + tokvecs = self.model[0].ops.flatten(tokvecs) states = self.moves.init_batch(docs) for gold in golds: self.moves.preprocess_gold(gold) @@ -775,6 +779,7 @@ cdef class Parser: for label in labels: self.moves.add_action(action, label) if self.model is True: + cfg['pretrained_dims'] = self.vocab.vectors_length self.model, cfg = self.Model(self.moves.n_moves, **cfg) self.cfg.update(cfg) @@ -856,9 +861,11 @@ cdef class Parser: msg = util.from_bytes(bytes_data, deserializers, exclude) if 'model' not in exclude: if self.model is True: - self.model, cfg = self.Model(self.moves.n_moves) + self.model, cfg = self.Model(self.moves.n_moves, + pretrained_dims=self.vocab.vectors_length) else: cfg = {} + cfg['pretrained_dims'] = self.vocab.vectors_length if 'tok2vec_model' in msg: self.model[0].from_bytes(msg['tok2vec_model']) if 'lower_model' in msg: