mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Try to fix parser training
This commit is contained in:
		
							parent
							
								
									3a6b93ae3a
								
							
						
					
					
						commit
						456c881ae3
					
				| 
						 | 
				
			
			@ -83,6 +83,8 @@ cdef class TransitionSystem:
 | 
			
		|||
    def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
 | 
			
		||||
        if state.is_final():
 | 
			
		||||
            return []
 | 
			
		||||
        if not self.has_gold(eg):
 | 
			
		||||
            return []
 | 
			
		||||
        cdef Pool mem = Pool()
 | 
			
		||||
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
 | 
			
		||||
        assert self.n_moves > 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -316,8 +316,9 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
        validate_examples(examples, "Parser.update")
 | 
			
		||||
        for multitask in self._multitasks:
 | 
			
		||||
            multitask.update(examples, drop=drop, sgd=sgd)
 | 
			
		||||
    
 | 
			
		||||
        examples = [eg for eg in examples if self.moves.has_gold(eg)]
 | 
			
		||||
        # We need to take care to act on the whole batch, because we might be
 | 
			
		||||
        # getting vectors via a listener.
 | 
			
		||||
        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
 | 
			
		||||
        if len(examples) == 0:
 | 
			
		||||
            return losses
 | 
			
		||||
        set_dropout_rate(self.model, drop)
 | 
			
		||||
| 
						 | 
				
			
			@ -347,7 +348,8 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
            states, golds, _ = self.moves.init_gold_batch(examples)
 | 
			
		||||
        if not states:
 | 
			
		||||
            return losses
 | 
			
		||||
        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
 | 
			
		||||
        docs = [eg.predicted for eg in examples]
 | 
			
		||||
        model, backprop_tok2vec = self.model.begin_update(docs)
 | 
			
		||||
 
 | 
			
		||||
        all_states = list(states)
 | 
			
		||||
        states_golds = list(zip(states, golds))
 | 
			
		||||
| 
						 | 
				
			
			@ -371,7 +373,6 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
        backprop_tok2vec(golds)
 | 
			
		||||
        if sgd not in (None, False):
 | 
			
		||||
            self.finish_update(sgd)
 | 
			
		||||
        docs = [eg.predicted for eg in examples]
 | 
			
		||||
        # If we want to set the annotations based on predictions, it's really
 | 
			
		||||
        # hard to avoid parsing the data twice :(. 
 | 
			
		||||
        # The issue is that we cut up the gold batch into sub-states, and that
 | 
			
		||||
| 
						 | 
				
			
			@ -601,7 +602,7 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
        states = []
 | 
			
		||||
        golds = []
 | 
			
		||||
        for state, eg, history in zip(all_states, examples, oracle_histories):
 | 
			
		||||
            if state.is_final():
 | 
			
		||||
            if not history:
 | 
			
		||||
                continue
 | 
			
		||||
            gold = self.moves.init_gold(state, eg)
 | 
			
		||||
            if len(history) < max_length:
 | 
			
		||||
| 
						 | 
				
			
			@ -609,6 +610,8 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
                golds.append(gold)
 | 
			
		||||
                continue
 | 
			
		||||
            for i in range(0, len(history), max_length):
 | 
			
		||||
                if state.is_final():
 | 
			
		||||
                    break
 | 
			
		||||
                start_state = state.copy()
 | 
			
		||||
                for clas in history[i:i+max_length]:
 | 
			
		||||
                    action = self.moves.c[clas]
 | 
			
		||||
| 
						 | 
				
			
			@ -618,6 +621,4 @@ cdef class Parser(TrainablePipe):
 | 
			
		|||
                if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
 | 
			
		||||
                    states.append(start_state)
 | 
			
		||||
                    golds.append(gold)
 | 
			
		||||
                if state.is_final():
 | 
			
		||||
                    break
 | 
			
		||||
        return states, golds, max_length
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user