mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Update beam_parser
This commit is contained in:
parent
c57bf6485d
commit
136a7a2322
|
@ -72,7 +72,7 @@ def get_templates(name):
|
|||
|
||||
|
||||
cdef int BEAM_WIDTH = 16
|
||||
cdef weight_t BEAM_DENSITY = 0.001
|
||||
cdef weight_t BEAM_DENSITY = 0.01
|
||||
|
||||
cdef class BeamParser(Parser):
|
||||
cdef public int beam_width
|
||||
|
@ -104,7 +104,7 @@ cdef class BeamParser(Parser):
|
|||
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||
pred.check_done(_check_final_state, NULL)
|
||||
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
|
||||
gold.initialize(_init_state, tokens.length, tokens.c)
|
||||
gold.check_done(_check_final_state, NULL)
|
||||
violn = MaxViolation()
|
||||
|
@ -116,14 +116,22 @@ cdef class BeamParser(Parser):
|
|||
if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
|
||||
break
|
||||
else:
|
||||
# The non-monotonic oracle makes it difficult to ensure final costs are
|
||||
# correct. Therefore do final correction
|
||||
for i in range(pred.size):
|
||||
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
|
||||
pred._states[i].loss = 0.0
|
||||
elif pred._states[i].loss == 0.0:
|
||||
pred._states[i].loss = 1.0
|
||||
violn.check_crf(pred, gold)
|
||||
min_grad = 0.001 ** (itn+1)
|
||||
_check_train_integrity(pred, gold, gold_parse, self.moves)
|
||||
histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
|
||||
min_grad = 0.001 ** (itn+1)
|
||||
histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad]
|
||||
random.shuffle(histories)
|
||||
for grad, hist in histories:
|
||||
assert not math.isnan(grad) and not math.isinf(grad)
|
||||
if abs(grad) >= min_grad:
|
||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
||||
_cleanup(pred)
|
||||
_cleanup(gold)
|
||||
return pred.loss
|
||||
|
@ -131,25 +139,26 @@ cdef class BeamParser(Parser):
|
|||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||
cdef Pool mem = Pool()
|
||||
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
||||
cdef ParserNeuralNet nn_model = None
|
||||
cdef ParserPerceptron ap_model = None
|
||||
if isinstance(self.model, ParserNeuralNet):
|
||||
nn_model = self.model
|
||||
mb = Minibatch(self.model.widths, beam.size)
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if stcls.c.is_final():
|
||||
nr_feat = 0
|
||||
else:
|
||||
nr_feat = self.model.set_featuresC(features, stcls.c)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||
self.model(mb)
|
||||
for i in range(beam.size):
|
||||
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||
else:
|
||||
ap_model = self.model
|
||||
raise NotImplementedError
|
||||
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if stcls.c.is_final():
|
||||
nr_feat = 0
|
||||
else:
|
||||
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||
self.model(mb)
|
||||
for i in range(beam.size):
|
||||
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.c.is_final():
|
||||
nr_feat = self.model.set_featuresC(features, stcls.c)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
|
@ -158,7 +167,10 @@ cdef class BeamParser(Parser):
|
|||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] < 1
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
if follow_gold:
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
else:
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
|
||||
|
@ -195,4 +207,51 @@ def _cleanup(Beam beam):
|
|||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <StateClass>_state
|
||||
return state.c.hash()
|
||||
if state.c.is_final():
|
||||
return 1
|
||||
else:
|
||||
return state.c.hash()
|
||||
|
||||
|
||||
def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves):
|
||||
for i in range(pred.size):
|
||||
if not pred._states[i].is_done or pred._states[i].loss == 0:
|
||||
continue
|
||||
state = <StateClass>pred.at(i)
|
||||
if is_gold(state, gold_parse, moves.strings) == True:
|
||||
print("Truth")
|
||||
for dep in gold_parse.orig_annot:
|
||||
print(dep[1], dep[3], dep[4])
|
||||
print("Cost", pred._states[i].loss)
|
||||
for j in range(gold_parse.length):
|
||||
print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
|
||||
acts = [moves.c[clas].move for clas in pred.histories[i]]
|
||||
labels = [moves.c[clas].label for clas in pred.histories[i]]
|
||||
print([moves.move_name(move, label) for move, label in zip(acts, labels)])
|
||||
raise Exception("Predicted state is gold-standard")
|
||||
for i in range(gold.size):
|
||||
if not gold._states[i].is_done:
|
||||
continue
|
||||
state = <StateClass>gold.at(i)
|
||||
if is_gold(state, gold_parse, moves.strings) == False:
|
||||
print("Truth")
|
||||
for dep in gold_parse.orig_annot:
|
||||
print(dep[1], dep[3], dep[4])
|
||||
print("Predicted good")
|
||||
for j in range(gold_parse.length):
|
||||
print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
|
||||
raise Exception("Gold parse is not gold-standard")
|
||||
|
||||
|
||||
def is_gold(StateClass state, GoldParse gold, StringStore strings):
|
||||
predicted = set()
|
||||
truth = set()
|
||||
for i in range(gold.length):
|
||||
if state.safe_get(i).dep:
|
||||
predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
|
||||
else:
|
||||
predicted.add((i, state.H(i), 'ROOT'))
|
||||
id_, word, tag, head, dep, ner = gold.orig_annot[i]
|
||||
truth.add((id_, head, dep))
|
||||
return truth == predicted
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user