mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Fix beam parsing. Starting to work with early update.
This commit is contained in:
		
							parent
							
								
									407ed4652d
								
							
						
					
					
						commit
						8b4abc24e3
					
				|  | @ -100,23 +100,32 @@ cdef class BeamParser(Parser): | ||||||
|         cdef Beam gold = Beam(self.moves.n_moves, self.beam_width) |         cdef Beam gold = Beam(self.moves.n_moves, self.beam_width) | ||||||
|         gold.initialize(_init_state, tokens.length, tokens.c) |         gold.initialize(_init_state, tokens.length, tokens.c) | ||||||
|         gold.check_done(_check_final_state, NULL) |         gold.check_done(_check_final_state, NULL) | ||||||
|         violn = MaxViolation() |  | ||||||
| 
 |  | ||||||
|         while not pred.is_done and not gold.is_done: |         while not pred.is_done and not gold.is_done: | ||||||
|  |             # We search separately here, to allow for ambiguity in the gold | ||||||
|  |             # parse. | ||||||
|             self._advance_beam(pred, gold_parse, False) |             self._advance_beam(pred, gold_parse, False) | ||||||
|             self._advance_beam(gold, gold_parse, True) |             self._advance_beam(gold, gold_parse, True) | ||||||
|  |             # Early update | ||||||
|             if pred.min_score > gold.score: |             if pred.min_score > gold.score: | ||||||
|                 break |                 break | ||||||
|         #print(pred.score, pred.min_score, gold.score) |         # Gather the partition function --- Z --- by which we can normalize the | ||||||
|  |         # scores into a probability distribution. The simple idea here is that | ||||||
|  |         # we clip the probability of all parses outside the beam to 0. | ||||||
|         cdef long double Z = 0.0 |         cdef long double Z = 0.0 | ||||||
|         for i in range(pred.size): |         for i in range(pred.size): | ||||||
|  |             # Make sure we've only got negative examples here. | ||||||
|  |             # Otherwise, we might double-count the gold. | ||||||
|             if pred._states[i].loss > 0:  |             if pred._states[i].loss > 0:  | ||||||
|                 Z += exp(pred._states[i].score) |                 Z += exp(pred._states[i].score) | ||||||
|         if Z > 0: |         if Z > 0: # If no negative examples, don't update. | ||||||
|             Z += exp(gold.score) |             Z += exp(gold.score) | ||||||
|             for i, hist in enumerate(pred.histories): |             for i, hist in enumerate(pred.histories): | ||||||
|                 if pred._states[i].loss > 0: |                 if pred._states[i].loss > 0: | ||||||
|  |                     # Update with the negative example. | ||||||
|  |                     # Gradient of loss is P(parse) - 0 | ||||||
|                     self._update_dense(tokens, hist, exp(pred._states[i].score) / Z) |                     self._update_dense(tokens, hist, exp(pred._states[i].score) / Z) | ||||||
|  |             # Update with the positive example. | ||||||
|  |             # Gradient of loss is P(parse) - 1 | ||||||
|             self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1) |             self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1) | ||||||
|         _cleanup(pred) |         _cleanup(pred) | ||||||
|         _cleanup(gold) |         _cleanup(gold) | ||||||
|  | @ -217,7 +226,6 @@ def _cleanup(Beam beam): | ||||||
| 
 | 
 | ||||||
| cdef hash_t _hash_state(void* _state, void* _) except 0: | cdef hash_t _hash_state(void* _state, void* _) except 0: | ||||||
|     state = <StateClass>_state |     state = <StateClass>_state | ||||||
|     #return <uint64_t>state.c |  | ||||||
|     return state.c.hash() |     return state.c.hash() | ||||||
| 
 | 
 | ||||||
| # | # | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user