mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Default to beam_update_prob 1
This commit is contained in:
parent
a61fd60681
commit
8e8724b55b
|
@ -356,7 +356,7 @@ cdef class Parser:
|
|||
losses.setdefault(self.name, 0.)
|
||||
# The probability we use beam update, instead of falling back to
|
||||
# a greedy update
|
||||
beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
|
||||
beam_update_prob = self.cfg.get('beam_update_prob', 1.0)
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
|
||||
return self.update_beam(docs, golds,
|
||||
self.cfg['beam_width'],
|
||||
|
@ -391,7 +391,7 @@ cdef class Parser:
|
|||
self.moves.preprocess_gold(gold)
|
||||
model, finish_update = self.model.begin_update(docs, drop=drop)
|
||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
||||
self.moves, self.nr_feature, 500, states, golds, model.state2vec,
|
||||
self.moves, self.nr_feature, 10000, states, golds, model.state2vec,
|
||||
model.vec2scores, width, drop=drop, losses=losses)
|
||||
for i, d_scores in enumerate(states_d_scores):
|
||||
losses[self.name] += (d_scores**2).sum()
|
||||
|
|
Loading…
Reference in New Issue
Block a user