mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix beam parse. Not sure if working
This commit is contained in:
		
							parent
							
								
									24b45b45c6
								
							
						
					
					
						commit
						c96d769836
					
				| 
						 | 
					@ -87,8 +87,9 @@ cdef class ParserBeam(object):
 | 
				
			||||||
    def _set_scores(self, Beam beam, scores):
 | 
					    def _set_scores(self, Beam beam, scores):
 | 
				
			||||||
        for i in range(beam.size):
 | 
					        for i in range(beam.size):
 | 
				
			||||||
            state = <StateClass>beam.at(i)
 | 
					            state = <StateClass>beam.at(i)
 | 
				
			||||||
            for j in range(beam.nr_class):
 | 
					            if not state.is_final():
 | 
				
			||||||
                beam.scores[i][j] = scores[i, j]
 | 
					                for j in range(beam.nr_class):
 | 
				
			||||||
 | 
					                    beam.scores[i][j] = scores[i, j]
 | 
				
			||||||
            self.moves.set_valid(beam.is_valid[i], state.c)
 | 
					            self.moves.set_valid(beam.is_valid[i], state.c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
 | 
					    def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
 | 
				
			||||||
| 
						 | 
					@ -137,8 +138,8 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        backprops.append((token_ids, bp_vectors, bp_scores))
 | 
					        backprops.append((token_ids, bp_vectors, bp_scores))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        p_scores = [scores[indices] for indices in p_indices]
 | 
					        p_scores = [numpy.ascontiguousarray(scores[indices], dtype='f') for indices in p_indices]
 | 
				
			||||||
        g_scores = [scores[indices] for indices in g_indices]
 | 
					        g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f')  for indices in g_indices]
 | 
				
			||||||
        pbeam.advance(p_scores)
 | 
					        pbeam.advance(p_scores)
 | 
				
			||||||
        gbeam.advance(g_scores, follow_gold=True)
 | 
					        gbeam.advance(g_scores, follow_gold=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -176,8 +177,8 @@ def get_states(pbeams, gbeams, beam_map):
 | 
				
			||||||
                beam_map[key] = len(states)
 | 
					                beam_map[key] = len(states)
 | 
				
			||||||
                states.append(<StateClass>gbeam.at(i))
 | 
					                states.append(<StateClass>gbeam.at(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    p_indices = numpy.asarray(p_indices, dtype='i')
 | 
					    p_indices = [numpy.asarray(idx, dtype='i') for idx in p_indices]
 | 
				
			||||||
    g_indices = numpy.asarray(g_indices, dtype='i')
 | 
					    g_indices = [numpy.asarray(idx, dtype='i') for idx in g_indices]
 | 
				
			||||||
    return states, p_indices, g_indices
 | 
					    return states, p_indices, g_indices
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -203,7 +204,9 @@ def get_gradient(nr_class, beam_maps, histories, losses):
 | 
				
			||||||
            key = tuple([eg_id])
 | 
					            key = tuple([eg_id])
 | 
				
			||||||
            for j, clas in enumerate(hist):
 | 
					            for j, clas in enumerate(hist):
 | 
				
			||||||
                i = beam_maps[j][key]
 | 
					                i = beam_maps[j][key]
 | 
				
			||||||
                grads[j][i, clas] = loss
 | 
					                # In step j, at state i action clas
 | 
				
			||||||
 | 
					                # resulted in loss
 | 
				
			||||||
 | 
					                grads[j][i, clas] += loss
 | 
				
			||||||
                key = key + tuple([clas])
 | 
					                key = key + tuple([clas])
 | 
				
			||||||
    return grads
 | 
					    return grads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user