mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Make parser update less hacky
This commit is contained in:
		
							parent
							
								
									8500d9b1da
								
							
						
					
					
						commit
						679efe79c8
					
				| 
						 | 
				
			
			@ -438,7 +438,7 @@ cdef class Parser:
 | 
			
		|||
        backprops = []
 | 
			
		||||
        d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
 | 
			
		||||
        cdef float loss = 0.
 | 
			
		||||
        while len(todo) >= 2:
 | 
			
		||||
        while todo:
 | 
			
		||||
            states, golds = zip(*todo)
 | 
			
		||||
 | 
			
		||||
            token_ids = self.get_token_ids(states)
 | 
			
		||||
| 
						 | 
				
			
			@ -465,15 +465,10 @@ cdef class Parser:
 | 
			
		|||
                backprops.append((token_ids, d_vector, bp_vector))
 | 
			
		||||
            self.transition_batch(states, scores)
 | 
			
		||||
            todo = [st for st in todo if not st[0].is_final()]
 | 
			
		||||
            if len(backprops) >= 20:
 | 
			
		||||
                self._make_updates(d_tokvecs,
 | 
			
		||||
                    backprops, sgd, cuda_stream)
 | 
			
		||||
                backprops = []
 | 
			
		||||
            if losses is not None:
 | 
			
		||||
                losses[self.name] += (d_scores**2).sum()
 | 
			
		||||
        if backprops:
 | 
			
		||||
            self._make_updates(d_tokvecs,
 | 
			
		||||
                backprops, sgd, cuda_stream)
 | 
			
		||||
        self._make_updates(d_tokvecs,
 | 
			
		||||
            backprops, sgd, cuda_stream)
 | 
			
		||||
        return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
 | 
			
		||||
 | 
			
		||||
    def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user