mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-03 07:13:40 +03:00
Fix dropout in parser
This commit is contained in:
parent
5b67bcbee0
commit
620df0414f
|
@ -249,11 +249,13 @@ cdef class Parser:
|
||||||
with Model.use_device('cpu'):
|
with Model.use_device('cpu'):
|
||||||
if depth == 0:
|
if depth == 0:
|
||||||
upper = chain()
|
upper = chain()
|
||||||
|
upper.is_noop = True
|
||||||
else:
|
else:
|
||||||
upper = chain(
|
upper = chain(
|
||||||
clone(Maxout(hidden_width), (depth-1)),
|
clone(Maxout(hidden_width), (depth-1)),
|
||||||
zero_init(Affine(nr_class))
|
zero_init(Affine(nr_class, drop_factor=0.0))
|
||||||
)
|
)
|
||||||
|
upper.is_noop = False
|
||||||
# TODO: This is an unfortunate hack atm!
|
# TODO: This is an unfortunate hack atm!
|
||||||
# Used to set input dimensions in network.
|
# Used to set input dimensions in network.
|
||||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
||||||
|
@ -364,7 +366,7 @@ cdef class Parser:
|
||||||
cdef np.ndarray scores
|
cdef np.ndarray scores
|
||||||
c_token_ids = <int*>token_ids.data
|
c_token_ids = <int*>token_ids.data
|
||||||
c_is_valid = <int*>is_valid.data
|
c_is_valid = <int*>is_valid.data
|
||||||
cdef int has_hidden = hasattr(vec2scores, 'W')
|
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
|
||||||
while not next_step.empty():
|
while not next_step.empty():
|
||||||
if not has_hidden:
|
if not has_hidden:
|
||||||
for i in cython.parallel.prange(
|
for i in cython.parallel.prange(
|
||||||
|
@ -426,7 +428,7 @@ cdef class Parser:
|
||||||
|
|
||||||
states = self.moves.init_batch(docs)
|
states = self.moves.init_batch(docs)
|
||||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
||||||
drop)
|
0.0)
|
||||||
|
|
||||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||||
if not s.is_final() and g is not None]
|
if not s.is_final() and g is not None]
|
||||||
|
@ -438,11 +440,14 @@ cdef class Parser:
|
||||||
states, golds = zip(*todo)
|
states, golds = zip(*todo)
|
||||||
|
|
||||||
token_ids = self.get_token_ids(states)
|
token_ids = self.get_token_ids(states)
|
||||||
vector, bp_vector = state2vec.begin_update(token_ids, drop=drop)
|
vector, bp_vector = state2vec.begin_update(token_ids, drop=0.0)
|
||||||
|
mask = vec2scores.ops.get_dropout_mask(vector.shape, drop)
|
||||||
|
vector *= mask
|
||||||
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
|
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
|
||||||
|
|
||||||
d_scores = self.get_batch_loss(states, golds, scores)
|
d_scores = self.get_batch_loss(states, golds, scores)
|
||||||
d_vector = bp_scores(d_scores, sgd=sgd)
|
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||||
|
d_vector *= mask
|
||||||
|
|
||||||
if isinstance(self.model[0].ops, CupyOps) \
|
if isinstance(self.model[0].ops, CupyOps) \
|
||||||
and not isinstance(token_ids, state2vec.ops.xp.ndarray):
|
and not isinstance(token_ids, state2vec.ops.xp.ndarray):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user