mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix drop_layer calculation
This commit is contained in:
parent
789e1a3980
commit
9f512e657a
12
spacy/_ml.py
12
spacy/_ml.py
|
@ -212,12 +212,14 @@ class PrecomputableMaxouts(Model):
|
||||||
|
|
||||||
def drop_layer(layer, factor=2.):
|
def drop_layer(layer, factor=2.):
|
||||||
def drop_layer_fwd(X, drop=0.):
|
def drop_layer_fwd(X, drop=0.):
|
||||||
drop *= factor
|
if drop <= 0.:
|
||||||
mask = layer.ops.get_dropout_mask((1,), drop)
|
|
||||||
if mask is None or mask > 0:
|
|
||||||
return layer.begin_update(X, drop=drop)
|
return layer.begin_update(X, drop=drop)
|
||||||
else:
|
else:
|
||||||
return X, lambda dX, sgd=None: dX
|
coinflip = layer.ops.xp.random.random()
|
||||||
|
if (coinflip / factor) >= drop:
|
||||||
|
return layer.begin_update(X, drop=drop)
|
||||||
|
else:
|
||||||
|
return X, lambda dX, sgd=None: dX
|
||||||
|
|
||||||
model = wrap(drop_layer_fwd, layer)
|
model = wrap(drop_layer_fwd, layer)
|
||||||
model.predict = layer
|
model.predict = layer
|
||||||
|
@ -362,6 +364,8 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
|
||||||
def backward(d_output, sgd=None):
|
def backward(d_output, sgd=None):
|
||||||
return (tokens, d_output)
|
return (tokens, d_output)
|
||||||
return vectors, backward
|
return vectors, backward
|
||||||
|
|
||||||
|
|
||||||
def fine_tune(embedding, combine=None):
|
def fine_tune(embedding, combine=None):
|
||||||
if combine is not None:
|
if combine is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user