Gate parser fine-tuning behind feature flag

This commit is contained in:
Matthew Honnibal 2017-08-09 16:40:42 -05:00
parent dbdd8afc4b
commit bbace204be

View File

@ -59,8 +59,9 @@ from ..structs cimport TokenC
from ..tokens.doc cimport Doc
from ..strings cimport StringStore
from ..gold cimport GoldParse
from ..attrs cimport TAG, DEP
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
USE_FINE_TUNE = True
def get_templates(*args, **kwargs):
return []
@ -237,7 +238,8 @@ cdef class Parser:
token_vector_width = util.env_opt('token_vector_width', token_vector_width)
hidden_width = util.env_opt('hidden_width', hidden_width)
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
tensors = fine_tune(Tok2Vec(token_vector_width, 7500, preprocess=doc2feats()))
tensors = fine_tune(Tok2Vec(token_vector_width, 7500,
preprocess=doc2feats(cols=[ID, NORM, PREFIX, SUFFIX, TAG])))
if parser_maxout_pieces == 1:
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
@ -367,6 +369,7 @@ cdef class Parser:
tokvecses = [tokvecses]
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
nr_state = len(docs)
@ -419,6 +422,7 @@ cdef class Parser:
cdef int nr_class = self.moves.n_moves
cdef StateClass stcls, output
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
@ -460,6 +464,7 @@ cdef class Parser:
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
@ -513,6 +518,7 @@ cdef class Parser:
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
if USE_FINE_TUNE:
bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs