Fix conflicts in nn_parser

This commit is contained in:
Matthew Honnibal 2017-08-18 20:55:58 +02:00
parent 1cec1efca7
commit d456d2efe1

View File

@ -51,6 +51,7 @@ from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats, rebatch
from ..compat import json_dumps
from . import _beam_utils
from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
@ -504,6 +505,9 @@ cdef class Parser:
losses[self.name] = 0.
docs, tokvec_lists = docs_tokvecs
tokvecs = self.model[0].ops.flatten(tokvec_lists)
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs += self.model[0].ops.flatten(my_tokvecs)
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
@ -557,8 +561,7 @@ cdef class Parser:
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
if USE_FINE_TUNE:
bp_my_tokvecs(d_tokvecs, sgd=sgd)
bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def update_beam(self, docs_tokvecs, golds, width=None, density=None,
@ -573,10 +576,9 @@ cdef class Parser:
lengths = [len(d) for d in docs]
assert min(lengths) >= 1
tokvecs = self.model[0].ops.flatten(tokvecs)
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
states = self.moves.init_batch(docs)
for gold in golds:
@ -607,8 +609,7 @@ cdef class Parser:
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
if USE_FINE_TUNE:
bp_my_tokvecs(d_tokvecs, sgd=sgd)
bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def _init_gold_batch(self, whole_docs, whole_golds):