mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Update beam_parser
This commit is contained in:
parent
c57bf6485d
commit
136a7a2322
|
@ -72,7 +72,7 @@ def get_templates(name):
|
||||||
|
|
||||||
|
|
||||||
cdef int BEAM_WIDTH = 16
|
cdef int BEAM_WIDTH = 16
|
||||||
cdef weight_t BEAM_DENSITY = 0.001
|
cdef weight_t BEAM_DENSITY = 0.01
|
||||||
|
|
||||||
cdef class BeamParser(Parser):
|
cdef class BeamParser(Parser):
|
||||||
cdef public int beam_width
|
cdef public int beam_width
|
||||||
|
@ -104,7 +104,7 @@ cdef class BeamParser(Parser):
|
||||||
pred.initialize(_init_state, tokens.length, tokens.c)
|
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||||
pred.check_done(_check_final_state, NULL)
|
pred.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
|
||||||
gold.initialize(_init_state, tokens.length, tokens.c)
|
gold.initialize(_init_state, tokens.length, tokens.c)
|
||||||
gold.check_done(_check_final_state, NULL)
|
gold.check_done(_check_final_state, NULL)
|
||||||
violn = MaxViolation()
|
violn = MaxViolation()
|
||||||
|
@ -116,13 +116,21 @@ cdef class BeamParser(Parser):
|
||||||
if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
|
if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# The non-monotonic oracle makes it difficult to ensure final costs are
|
||||||
|
# correct. Therefore do final correction
|
||||||
|
for i in range(pred.size):
|
||||||
|
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
|
||||||
|
pred._states[i].loss = 0.0
|
||||||
|
elif pred._states[i].loss == 0.0:
|
||||||
|
pred._states[i].loss = 1.0
|
||||||
violn.check_crf(pred, gold)
|
violn.check_crf(pred, gold)
|
||||||
min_grad = 0.001 ** (itn+1)
|
_check_train_integrity(pred, gold, gold_parse, self.moves)
|
||||||
histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
|
histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
|
||||||
|
min_grad = 0.001 ** (itn+1)
|
||||||
|
histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad]
|
||||||
random.shuffle(histories)
|
random.shuffle(histories)
|
||||||
for grad, hist in histories:
|
for grad, hist in histories:
|
||||||
assert not math.isnan(grad) and not math.isinf(grad)
|
assert not math.isnan(grad) and not math.isinf(grad)
|
||||||
if abs(grad) >= min_grad:
|
|
||||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
self.model._update_from_history(self.moves, tokens, hist, grad)
|
||||||
_cleanup(pred)
|
_cleanup(pred)
|
||||||
_cleanup(gold)
|
_cleanup(gold)
|
||||||
|
@ -131,25 +139,26 @@ cdef class BeamParser(Parser):
|
||||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
||||||
cdef ParserNeuralNet nn_model = None
|
|
||||||
cdef ParserPerceptron ap_model = None
|
|
||||||
if isinstance(self.model, ParserNeuralNet):
|
if isinstance(self.model, ParserNeuralNet):
|
||||||
nn_model = self.model
|
mb = Minibatch(self.model.widths, beam.size)
|
||||||
else:
|
|
||||||
ap_model = self.model
|
|
||||||
raise NotImplementedError
|
|
||||||
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
|
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
if stcls.c.is_final():
|
if stcls.c.is_final():
|
||||||
nr_feat = 0
|
nr_feat = 0
|
||||||
else:
|
else:
|
||||||
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
nr_feat = self.model.set_featuresC(features, stcls.c)
|
||||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||||
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||||
self.model(mb)
|
self.model(mb)
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||||
|
else:
|
||||||
|
for i in range(beam.size):
|
||||||
|
stcls = <StateClass>beam.at(i)
|
||||||
|
if not stcls.c.is_final():
|
||||||
|
nr_feat = self.model.set_featuresC(features, stcls.c)
|
||||||
|
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||||
|
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
||||||
if gold is not None:
|
if gold is not None:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
|
@ -158,6 +167,9 @@ cdef class BeamParser(Parser):
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
beam.is_valid[i][j] *= beam.costs[i][j] < 1
|
beam.is_valid[i][j] *= beam.costs[i][j] < 1
|
||||||
|
if follow_gold:
|
||||||
|
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||||
|
else:
|
||||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
|
@ -195,4 +207,51 @@ def _cleanup(Beam beam):
|
||||||
|
|
||||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||||
state = <StateClass>_state
|
state = <StateClass>_state
|
||||||
|
if state.c.is_final():
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
return state.c.hash()
|
return state.c.hash()
|
||||||
|
|
||||||
|
|
||||||
|
def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves):
|
||||||
|
for i in range(pred.size):
|
||||||
|
if not pred._states[i].is_done or pred._states[i].loss == 0:
|
||||||
|
continue
|
||||||
|
state = <StateClass>pred.at(i)
|
||||||
|
if is_gold(state, gold_parse, moves.strings) == True:
|
||||||
|
print("Truth")
|
||||||
|
for dep in gold_parse.orig_annot:
|
||||||
|
print(dep[1], dep[3], dep[4])
|
||||||
|
print("Cost", pred._states[i].loss)
|
||||||
|
for j in range(gold_parse.length):
|
||||||
|
print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
|
||||||
|
acts = [moves.c[clas].move for clas in pred.histories[i]]
|
||||||
|
labels = [moves.c[clas].label for clas in pred.histories[i]]
|
||||||
|
print([moves.move_name(move, label) for move, label in zip(acts, labels)])
|
||||||
|
raise Exception("Predicted state is gold-standard")
|
||||||
|
for i in range(gold.size):
|
||||||
|
if not gold._states[i].is_done:
|
||||||
|
continue
|
||||||
|
state = <StateClass>gold.at(i)
|
||||||
|
if is_gold(state, gold_parse, moves.strings) == False:
|
||||||
|
print("Truth")
|
||||||
|
for dep in gold_parse.orig_annot:
|
||||||
|
print(dep[1], dep[3], dep[4])
|
||||||
|
print("Predicted good")
|
||||||
|
for j in range(gold_parse.length):
|
||||||
|
print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
|
||||||
|
raise Exception("Gold parse is not gold-standard")
|
||||||
|
|
||||||
|
|
||||||
|
def is_gold(StateClass state, GoldParse gold, StringStore strings):
|
||||||
|
predicted = set()
|
||||||
|
truth = set()
|
||||||
|
for i in range(gold.length):
|
||||||
|
if state.safe_get(i).dep:
|
||||||
|
predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
|
||||||
|
else:
|
||||||
|
predicted.add((i, state.H(i), 'ROOT'))
|
||||||
|
id_, word, tag, head, dep, ner = gold.orig_annot[i]
|
||||||
|
truth.add((id_, head, dep))
|
||||||
|
return truth == predicted
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user