Pass dropout through to embed tables

This commit is contained in:
Matthew Honnibal 2017-10-06 06:09:18 -05:00
parent 21d11936fe
commit fbba7c517e

View File

@ -266,15 +266,10 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
ops = embed.ops ops = embed.ops
def add_history_fwd(vectors_hists, drop=0.): def add_history_fwd(vectors_hists, drop=0.):
vectors, hist_ids = vectors_hists vectors, hist_ids = vectors_hists
hist_feats, bp_hists = embed.begin_update(hist_ids) hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop)
outputs = ops.xp.hstack((vectors, hist_feats)) outputs = ops.xp.hstack((vectors, hist_feats))
mask = ops.get_dropout_mask(outputs.shape, drop)
if mask is not None:
outputs *= mask
def add_history_bwd(d_outputs, sgd=None): def add_history_bwd(d_outputs, sgd=None):
if mask is not None:
d_outputs *= mask
d_vectors = d_outputs[:, :vectors.shape[1]] d_vectors = d_outputs[:, :vectors.shape[1]]
d_hists = d_outputs[:, vectors.shape[1]:] d_hists = d_outputs[:, vectors.shape[1]:]
bp_hists(d_hists, sgd=sgd) bp_hists(d_hists, sgd=sgd)