mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Pass dropout in parser
This commit is contained in:
parent
158e177cae
commit
cdb2d83e16
|
@ -533,7 +533,7 @@ cdef class Parser:
|
|||
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
0.0)
|
||||
drop)
|
||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||
if not s.is_final() and g is not None]
|
||||
if not todo:
|
||||
|
@ -598,7 +598,7 @@ cdef class Parser:
|
|||
self.moves.preprocess_gold(gold)
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, 0.0)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, drop)
|
||||
|
||||
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
|
||||
states, golds,
|
||||
|
@ -685,7 +685,7 @@ cdef class Parser:
|
|||
tok2vec, lower, upper = self.model
|
||||
tokvecs, bp_tokvecs = tok2vec.begin_update(docs, drop=dropout)
|
||||
state2vec = precompute_hiddens(len(docs), tokvecs,
|
||||
lower, stream, drop=dropout)
|
||||
lower, stream, drop=0.0)
|
||||
return (tokvecs, bp_tokvecs), state2vec, upper
|
||||
|
||||
nr_feature = 8
|
||||
|
|
Loading…
Reference in New Issue
Block a user