Update beam_parser

This commit is contained in:
Matthew Honnibal 2016-08-29 14:24:14 +02:00
parent c57bf6485d
commit 136a7a2322

View File

@ -72,7 +72,7 @@ def get_templates(name):
cdef int BEAM_WIDTH = 16 cdef int BEAM_WIDTH = 16
cdef weight_t BEAM_DENSITY = 0.001 cdef weight_t BEAM_DENSITY = 0.01
cdef class BeamParser(Parser): cdef class BeamParser(Parser):
cdef public int beam_width cdef public int beam_width
@ -104,7 +104,7 @@ cdef class BeamParser(Parser):
pred.initialize(_init_state, tokens.length, tokens.c) pred.initialize(_init_state, tokens.length, tokens.c)
pred.check_done(_check_final_state, NULL) pred.check_done(_check_final_state, NULL)
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
gold.initialize(_init_state, tokens.length, tokens.c) gold.initialize(_init_state, tokens.length, tokens.c)
gold.check_done(_check_final_state, NULL) gold.check_done(_check_final_state, NULL)
violn = MaxViolation() violn = MaxViolation()
@ -116,13 +116,21 @@ cdef class BeamParser(Parser):
if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
break break
else: else:
# The non-monotonic oracle makes it difficult to ensure that the final costs
# are correct. Therefore, do a final correction of the losses here.
for i in range(pred.size):
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
pred._states[i].loss = 0.0
elif pred._states[i].loss == 0.0:
pred._states[i].loss = 1.0
violn.check_crf(pred, gold) violn.check_crf(pred, gold)
min_grad = 0.001 ** (itn+1) _check_train_integrity(pred, gold, gold_parse, self.moves)
histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist) histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
min_grad = 0.001 ** (itn+1)
histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad]
random.shuffle(histories) random.shuffle(histories)
for grad, hist in histories: for grad, hist in histories:
assert not math.isnan(grad) and not math.isinf(grad) assert not math.isnan(grad) and not math.isinf(grad)
if abs(grad) >= min_grad:
self.model._update_from_history(self.moves, tokens, hist, grad) self.model._update_from_history(self.moves, tokens, hist, grad)
_cleanup(pred) _cleanup(pred)
_cleanup(gold) _cleanup(gold)
@ -131,25 +139,26 @@ cdef class BeamParser(Parser):
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef Pool mem = Pool() cdef Pool mem = Pool()
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC)) features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
cdef ParserNeuralNet nn_model = None
cdef ParserPerceptron ap_model = None
if isinstance(self.model, ParserNeuralNet): if isinstance(self.model, ParserNeuralNet):
nn_model = self.model mb = Minibatch(self.model.widths, beam.size)
else:
ap_model = self.model
raise NotImplementedError
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
if stcls.c.is_final(): if stcls.c.is_final():
nr_feat = 0 nr_feat = 0
else: else:
nr_feat = nn_model._set_featuresC(features, stcls.c) nr_feat = self.model.set_featuresC(features, stcls.c)
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c)
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
self.model(mb) self.model(mb)
for i in range(beam.size): for i in range(beam.size):
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
else:
for i in range(beam.size):
stcls = <StateClass>beam.at(i)
if not stcls.c.is_final():
nr_feat = self.model.set_featuresC(features, stcls.c)
self.moves.set_valid(beam.is_valid[i], stcls.c)
self.model.set_scoresC(beam.scores[i], features, nr_feat)
if gold is not None: if gold is not None:
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
@ -158,6 +167,9 @@ cdef class BeamParser(Parser):
if follow_gold: if follow_gold:
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] < 1 beam.is_valid[i][j] *= beam.costs[i][j] < 1
if follow_gold:
beam.advance(_transition_state, NULL, <void*>self.moves.c)
else:
beam.advance(_transition_state, _hash_state, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
@ -195,4 +207,51 @@ def _cleanup(Beam beam):
cdef hash_t _hash_state(void* _state, void* _) except 0:
    # Hash callback passed to beam.advance() when the beam is not following
    # the gold-standard derivation. The second argument is unused.
    state = <StateClass>_state
    if not state.c.is_final():
        return state.c.hash()
    # Every final state is given the same constant hash value.
    return 1
def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves):
    """Debugging check that beam losses agree with the gold annotation.

    Two conditions are verified, using is_gold() as the reference:

    * a finished state in `pred` that carries a non-zero loss must NOT
      match the gold-standard annotation;
    * every finished state in `gold` MUST match the gold-standard
      annotation.

    On the first violation found, the reference annotation and the
    offending analysis are printed and an Exception is raised.
    """
    # Predicted beam: finished states with a non-zero loss.
    for i in range(pred.size):
        if pred._states[i].is_done and pred._states[i].loss != 0:
            state = <StateClass>pred.at(i)
            if is_gold(state, gold_parse, moves.strings):
                print("Truth")
                for annot in gold_parse.orig_annot:
                    print(annot[1], annot[3], annot[4])
                print("Cost", pred._states[i].loss)
                for tok in range(gold_parse.length):
                    print(gold_parse.orig_annot[tok][1], state.H(tok), moves.strings[state.safe_get(tok).dep])
                # Reconstruct the names of the transitions in this state's history.
                move_names = []
                for clas in pred.histories[i]:
                    move_names.append(moves.move_name(moves.c[clas].move, moves.c[clas].label))
                print(move_names)
                raise Exception("Predicted state is gold-standard")
    # Gold beam: every finished state has to reproduce the annotation.
    for i in range(gold.size):
        if gold._states[i].is_done:
            state = <StateClass>gold.at(i)
            if not is_gold(state, gold_parse, moves.strings):
                print("Truth")
                for annot in gold_parse.orig_annot:
                    print(annot[1], annot[3], annot[4])
                print("Predicted good")
                for tok in range(gold_parse.length):
                    print(gold_parse.orig_annot[tok][1], state.H(tok), moves.strings[state.safe_get(tok).dep])
                raise Exception("Gold parse is not gold-standard")
def is_gold(StateClass state, GoldParse gold, StringStore strings):
    """Return True if the analysis in `state` equals the annotation in `gold`.

    For each of the first `gold.length` tokens, a triple
    (index, state.H(index), label) is collected from `state` and a triple
    (id, head, dep) from `gold.orig_annot`. The label is looked up in
    `strings`, or is 'ROOT' when the token's dep value is falsy. The result
    is whether the two sets of triples are equal.
    """
    truth = set()
    predicted = set()
    for i in range(gold.length):
        dep_id = state.safe_get(i).dep
        head_of_i = state.H(i)
        label = strings[dep_id] if dep_id else 'ROOT'
        predicted.add((i, head_of_i, label))
        # Only the id, head and dep fields of the annotation are compared.
        id_, _word, _tag, head, dep, _ner = gold.orig_annot[i]
        truth.add((id_, head, dep))
    return predicted == truth