mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
058372d120
12
spacy/_ml.py
12
spacy/_ml.py
|
@ -212,12 +212,14 @@ class PrecomputableMaxouts(Model):
|
|||
|
||||
def drop_layer(layer, factor=2.):
|
||||
def drop_layer_fwd(X, drop=0.):
|
||||
drop *= factor
|
||||
mask = layer.ops.get_dropout_mask((1,), drop)
|
||||
if mask is None or mask > 0:
|
||||
if drop <= 0.:
|
||||
return layer.begin_update(X, drop=drop)
|
||||
else:
|
||||
return X, lambda dX, sgd=None: dX
|
||||
coinflip = layer.ops.xp.random.random()
|
||||
if (coinflip / factor) >= drop:
|
||||
return layer.begin_update(X, drop=drop)
|
||||
else:
|
||||
return X, lambda dX, sgd=None: dX
|
||||
|
||||
model = wrap(drop_layer_fwd, layer)
|
||||
model.predict = layer
|
||||
|
@ -362,6 +364,8 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
|
|||
def backward(d_output, sgd=None):
|
||||
return (tokens, d_output)
|
||||
return vectors, backward
|
||||
|
||||
|
||||
def fine_tune(embedding, combine=None):
|
||||
if combine is not None:
|
||||
raise NotImplementedError(
|
||||
|
|
|
@ -705,7 +705,7 @@ cdef class Parser:
|
|||
lower, stream, drop=dropout)
|
||||
return state2vec, upper
|
||||
|
||||
nr_feature = 8
|
||||
nr_feature = 13
|
||||
|
||||
def get_token_ids(self, states):
|
||||
cdef StateClass state
|
||||
|
|
Loading…
Reference in New Issue
Block a user