Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
			
		
		
		
	Update beam_parser
This commit is contained in:
		
	Parent: c57bf6485d
	Commit: 136a7a2322
					
				|  | @ -72,7 +72,7 @@ def get_templates(name): | |||
| 
 | ||||
| 
 | ||||
| cdef int BEAM_WIDTH = 16 | ||||
| cdef weight_t BEAM_DENSITY = 0.001 | ||||
| cdef weight_t BEAM_DENSITY = 0.01 | ||||
| 
 | ||||
| cdef class BeamParser(Parser): | ||||
|     cdef public int beam_width | ||||
|  | @ -104,7 +104,7 @@ cdef class BeamParser(Parser): | |||
|         pred.initialize(_init_state, tokens.length, tokens.c) | ||||
|         pred.check_done(_check_final_state, NULL) | ||||
|          | ||||
|         cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) | ||||
|         cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) | ||||
|         gold.initialize(_init_state, tokens.length, tokens.c) | ||||
|         gold.check_done(_check_final_state, NULL) | ||||
|         violn = MaxViolation() | ||||
|  | @ -116,14 +116,22 @@ cdef class BeamParser(Parser): | |||
|             if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): | ||||
|                 break | ||||
|         else: | ||||
|             # The non-monotonic oracle makes it difficult to ensure final costs are | ||||
|             # correct. Therefore do final correction | ||||
|             for i in range(pred.size): | ||||
|                 if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings): | ||||
|                     pred._states[i].loss = 0.0 | ||||
|                 elif pred._states[i].loss == 0.0: | ||||
|                     pred._states[i].loss = 1.0 | ||||
|             violn.check_crf(pred, gold) | ||||
|         min_grad = 0.001 ** (itn+1) | ||||
|         _check_train_integrity(pred, gold, gold_parse, self.moves) | ||||
|         histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist) | ||||
|         min_grad = 0.001 ** (itn+1) | ||||
|         histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad] | ||||
|         random.shuffle(histories) | ||||
|         for grad, hist in histories: | ||||
|             assert not math.isnan(grad) and not math.isinf(grad) | ||||
|             if abs(grad) >= min_grad: | ||||
|                 self.model._update_from_history(self.moves, tokens, hist, grad) | ||||
|             self.model._update_from_history(self.moves, tokens, hist, grad) | ||||
|         _cleanup(pred) | ||||
|         _cleanup(gold) | ||||
|         return pred.loss | ||||
|  | @ -131,25 +139,26 @@ cdef class BeamParser(Parser): | |||
|     def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): | ||||
|         cdef Pool mem = Pool() | ||||
|         features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC)) | ||||
|         cdef ParserNeuralNet nn_model = None | ||||
|         cdef ParserPerceptron ap_model = None | ||||
|         if isinstance(self.model, ParserNeuralNet): | ||||
|             nn_model = self.model | ||||
|             mb = Minibatch(self.model.widths, beam.size) | ||||
|             for i in range(beam.size): | ||||
|                 stcls = <StateClass>beam.at(i) | ||||
|                 if stcls.c.is_final(): | ||||
|                     nr_feat = 0 | ||||
|                 else: | ||||
|                     nr_feat = self.model.set_featuresC(features, stcls.c) | ||||
|                     self.moves.set_valid(beam.is_valid[i], stcls.c) | ||||
|                 mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) | ||||
|             self.model(mb) | ||||
|             for i in range(beam.size): | ||||
|                 memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) | ||||
|         else: | ||||
|             ap_model = self.model | ||||
|             raise NotImplementedError | ||||
|         cdef Minibatch mb = Minibatch(nn_model.widths, beam.size) | ||||
|         for i in range(beam.size): | ||||
|             stcls = <StateClass>beam.at(i) | ||||
|             if stcls.c.is_final(): | ||||
|                 nr_feat = 0 | ||||
|             else: | ||||
|                 nr_feat = nn_model._set_featuresC(features, stcls.c) | ||||
|                 self.moves.set_valid(beam.is_valid[i], stcls.c) | ||||
|             mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) | ||||
|         self.model(mb) | ||||
|         for i in range(beam.size): | ||||
|             memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) | ||||
|             for i in range(beam.size): | ||||
|                 stcls = <StateClass>beam.at(i) | ||||
|                 if not stcls.c.is_final(): | ||||
|                     nr_feat = self.model.set_featuresC(features, stcls.c) | ||||
|                     self.moves.set_valid(beam.is_valid[i], stcls.c) | ||||
|                     self.model.set_scoresC(beam.scores[i], features, nr_feat) | ||||
|         if gold is not None: | ||||
|             for i in range(beam.size): | ||||
|                 stcls = <StateClass>beam.at(i) | ||||
|  | @ -158,7 +167,10 @@ cdef class BeamParser(Parser): | |||
|                     if follow_gold: | ||||
|                         for j in range(self.moves.n_moves): | ||||
|                             beam.is_valid[i][j] *= beam.costs[i][j] < 1 | ||||
|         beam.advance(_transition_state, _hash_state, <void*>self.moves.c) | ||||
|         if follow_gold: | ||||
|             beam.advance(_transition_state, NULL, <void*>self.moves.c) | ||||
|         else: | ||||
|             beam.advance(_transition_state, _hash_state, <void*>self.moves.c) | ||||
|         beam.check_done(_check_final_state, NULL) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -195,4 +207,51 @@ def _cleanup(Beam beam): | |||
| 
 | ||||
| cdef hash_t _hash_state(void* _state, void* _) except 0: | ||||
|     state = <StateClass>_state | ||||
|     return state.c.hash() | ||||
|     if state.c.is_final(): | ||||
|         return 1 | ||||
|     else: | ||||
|         return state.c.hash() | ||||
| 
 | ||||
| 
 | ||||
| def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves): | ||||
|     for i in range(pred.size): | ||||
|         if not pred._states[i].is_done or pred._states[i].loss == 0: | ||||
|             continue | ||||
|         state = <StateClass>pred.at(i) | ||||
|         if is_gold(state, gold_parse, moves.strings) == True: | ||||
|             print("Truth") | ||||
|             for dep in gold_parse.orig_annot: | ||||
|                 print(dep[1], dep[3], dep[4]) | ||||
|             print("Cost", pred._states[i].loss) | ||||
|             for j in range(gold_parse.length): | ||||
|                 print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) | ||||
|             acts = [moves.c[clas].move for clas in pred.histories[i]] | ||||
|             labels = [moves.c[clas].label for clas in pred.histories[i]] | ||||
|             print([moves.move_name(move, label) for move, label in zip(acts, labels)]) | ||||
|             raise Exception("Predicted state is gold-standard") | ||||
|     for i in range(gold.size): | ||||
|         if not gold._states[i].is_done: | ||||
|             continue | ||||
|         state = <StateClass>gold.at(i) | ||||
|         if is_gold(state, gold_parse, moves.strings) == False: | ||||
|             print("Truth") | ||||
|             for dep in gold_parse.orig_annot: | ||||
|                 print(dep[1], dep[3], dep[4]) | ||||
|             print("Predicted good") | ||||
|             for j in range(gold_parse.length): | ||||
|                 print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) | ||||
|             raise Exception("Gold parse is not gold-standard") | ||||
| 
 | ||||
| 
 | ||||
| def is_gold(StateClass state, GoldParse gold, StringStore strings): | ||||
|     predicted = set() | ||||
|     truth = set() | ||||
|     for i in range(gold.length): | ||||
|         if state.safe_get(i).dep: | ||||
|             predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) | ||||
|         else: | ||||
|             predicted.add((i, state.H(i), 'ROOT')) | ||||
|         id_, word, tag, head, dep, ner = gold.orig_annot[i] | ||||
|         truth.add((id_, head, dep)) | ||||
|     return truth == predicted | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user