mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Pass dropout through to embed tables
This commit is contained in:
parent
21d11936fe
commit
fbba7c517e
|
@ -266,15 +266,10 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
|
||||||
ops = embed.ops
|
ops = embed.ops
|
||||||
def add_history_fwd(vectors_hists, drop=0.):
|
def add_history_fwd(vectors_hists, drop=0.):
|
||||||
vectors, hist_ids = vectors_hists
|
vectors, hist_ids = vectors_hists
|
||||||
hist_feats, bp_hists = embed.begin_update(hist_ids)
|
hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop)
|
||||||
outputs = ops.xp.hstack((vectors, hist_feats))
|
outputs = ops.xp.hstack((vectors, hist_feats))
|
||||||
mask = ops.get_dropout_mask(outputs.shape, drop)
|
|
||||||
if mask is not None:
|
|
||||||
outputs *= mask
|
|
||||||
|
|
||||||
def add_history_bwd(d_outputs, sgd=None):
|
def add_history_bwd(d_outputs, sgd=None):
|
||||||
if mask is not None:
|
|
||||||
d_outputs *= mask
|
|
||||||
d_vectors = d_outputs[:, :vectors.shape[1]]
|
d_vectors = d_outputs[:, :vectors.shape[1]]
|
||||||
d_hists = d_outputs[:, vectors.shape[1]:]
|
d_hists = d_outputs[:, vectors.shape[1]:]
|
||||||
bp_hists(d_hists, sgd=sgd)
|
bp_hists(d_hists, sgd=sgd)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user