Fix beam parsing. Starting to work with early update.

This commit is contained in:
Matthew Honnibal 2016-07-24 10:45:50 +02:00
parent 407ed4652d
commit 8b4abc24e3

View File

@ -100,23 +100,32 @@ cdef class BeamParser(Parser):
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width) cdef Beam gold = Beam(self.moves.n_moves, self.beam_width)
gold.initialize(_init_state, tokens.length, tokens.c) gold.initialize(_init_state, tokens.length, tokens.c)
gold.check_done(_check_final_state, NULL) gold.check_done(_check_final_state, NULL)
violn = MaxViolation()
while not pred.is_done and not gold.is_done: while not pred.is_done and not gold.is_done:
# We search separately here, to allow for ambiguity in the gold
# parse.
self._advance_beam(pred, gold_parse, False) self._advance_beam(pred, gold_parse, False)
self._advance_beam(gold, gold_parse, True) self._advance_beam(gold, gold_parse, True)
# Early update
if pred.min_score > gold.score: if pred.min_score > gold.score:
break break
#print(pred.score, pred.min_score, gold.score) # Gather the partition function --- Z --- by which we can normalize the
# scores into a probability distribution. The simple idea here is that
# we clip the probability of all parses outside the beam to 0.
cdef long double Z = 0.0 cdef long double Z = 0.0
for i in range(pred.size): for i in range(pred.size):
if pred._states[i].loss > 0: # Make sure we've only got negative examples here.
# Otherwise, we might double-count the gold.
if pred._states[i].loss > 0:
Z += exp(pred._states[i].score) Z += exp(pred._states[i].score)
if Z > 0: if Z > 0: # If no negative examples, don't update.
Z += exp(gold.score) Z += exp(gold.score)
for i, hist in enumerate(pred.histories): for i, hist in enumerate(pred.histories):
if pred._states[i].loss > 0: if pred._states[i].loss > 0:
# Update with the negative example.
# Gradient of loss is P(parse) - 0
self._update_dense(tokens, hist, exp(pred._states[i].score) / Z) self._update_dense(tokens, hist, exp(pred._states[i].score) / Z)
# Update with the positive example.
# Gradient of loss is P(parse) - 1
self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1) self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1)
_cleanup(pred) _cleanup(pred)
_cleanup(gold) _cleanup(gold)
@ -217,7 +226,6 @@ def _cleanup(Beam beam):
cdef hash_t _hash_state(void* _state, void* _) except 0: cdef hash_t _hash_state(void* _state, void* _) except 0:
state = <StateClass>_state state = <StateClass>_state
#return <uint64_t>state.c
return state.c.hash() return state.c.hash()
# #