mirror of https://github.com/explosion/spaCy.git

commit 2a91d641e6
parent 1f292bfd17

    Add dropout to parser
@@ -193,13 +193,11 @@ cdef class Parser:
         elif 'features' not in cfg:
             cfg['features'] = self.feature_templates
         self.model = ParserModel(self.moves.n_moves, cfg['features'],
-                                 size=2**18,
+                                 size=2**14,
                                  learn_rate=cfg.get('learn_rate', 0.001))
         #self.model.l1_penalty = cfg.get('L1', 1e-8)
         #self.model.learn_rate = cfg.get('learn_rate', 0.001)
         #self.model.l1_penalty = cfg.get('L1', 0.0)
 
-        self.optimizer = SGD(NumpyOps(), cfg.get('learn_rate', 0.001),
-                             momentum=0.9)
+        self.optimizer = Adam(NumpyOps(), cfg.get('learn_rate', 0.001))
 
         self.cfg = cfg
 
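The setup change above swaps the parser's optimizer from SGD with momentum to Adam, and reduces the ParserModel size parameter from 2**18 to 2**14. Purely as an illustrative aside, and not the thinc implementation the commit actually uses, the two update rules differ roughly as follows for a weight array w and gradient g:

    import numpy as np

    def sgd_momentum_update(w, g, v, learn_rate=0.001, momentum=0.9):
        # Classical momentum: keep a velocity and step along it.
        v[:] = momentum * v - learn_rate * g
        w += v

    def adam_update(w, g, m, v, t, learn_rate=0.001,
                    beta1=0.9, beta2=0.999, eps=1e-8):
        # Adam: per-weight step sizes from running estimates of the
        # gradient's first and second moments, with bias correction
        # (t is the 1-based update count).
        m[:] = beta1 * m + (1 - beta1) * g
        v[:] = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w -= learn_rate * m_hat / (np.sqrt(v_hat) + eps)

Adam adapts the step size per weight, which tends to be less sensitive to the global learn_rate than SGD with a fixed momentum.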
@@ -337,9 +335,19 @@ cdef class Parser:
         cdef Transition action
         words = [w.text for w in tokens]
 
+        cdef int i
+        cdef double[::1] py_dropout
+        cdef double* dropout
         while not stcls.is_final():
 
             nr_feat = self.model.set_featuresC(context, features, stcls.c)
+            py_dropout = numpy.random.uniform(0., 1., nr_feat)
+            dropout = &py_dropout[0]
+            for i in range(nr_feat):
+                if dropout[i] < 0.5:
+                    features[i].value = 0
+                else:
+                    features[i].value *= 2
             self.moves.set_costs(is_valid, costs, stcls, gold)
             self.model.set_scoresC(scores, features, nr_feat)
 
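The loop added above is inverted dropout over the parser's extracted feature values: each feature is zeroed with probability 0.5, and the survivors are doubled so the expected total input to the model is unchanged, which means no rescaling is needed at prediction time. A rough NumPy sketch of the same idea, operating on a plain float array rather than the C feature structs (illustrative only):

    import numpy

    def dropout_features(values, drop_prob=0.5):
        # Zero each value with probability drop_prob and scale the
        # survivors by 1 / (1 - drop_prob), as in the Cython loop above.
        keep = numpy.random.uniform(0., 1., values.shape) >= drop_prob
        return values * keep / (1. - drop_prob)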
@@ -347,6 +355,9 @@ cdef class Parser:
             best = arg_max_if_gold(scores, costs, nr_class)
 
             self.model.regression_lossC(d_scores, scores, costs)
             for i in range(nr_class):
                 if not is_valid[i]:
                     d_scores[i] = 0
             self.model.set_gradientC(d_scores, features, nr_feat)
 
             action = self.moves.c[guess]
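In the hunk above, regression_lossC computes d_scores from the scores and gold costs, and the gradient entries for transitions that are invalid in the current state are then zeroed before set_gradientC is called, so the update never pushes the model towards moves the parser could not take. A hedged NumPy sketch of that masking step (the names here are illustrative, not the parser's API):

    import numpy

    def mask_invalid_gradients(d_scores, is_valid):
        # Zero the gradient for every class whose transition is invalid.
        invalid = ~numpy.asarray(is_valid, dtype=bool)
        d_scores[invalid] = 0.
        return d_scores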
@@ -354,7 +365,7 @@ cdef class Parser:
             #print(scores[guess], scores[best], d_scores[guess], costs[guess],
             #    self.moves.move_name(action.move, action.label), stcls.print_state(words))
 
-            loss += scores[guess]
+            loss += abs(scores[guess] + costs[guess])
             memset(context, 0, sizeof(context))
             memset(features, 0, sizeof(features[0]) * nr_feat)
             memset(scores, 0, sizeof(scores[0]) * nr_class)
@@ -363,8 +374,7 @@ cdef class Parser:
             for i in range(nr_class):
                 is_valid[i] = 1
         #if itn % 100 == 0:
         #    self.optimizer(self.model.model[0].ravel(),
         #        self.model.model[1].ravel(), key=1)
         #    self.model.finish_update(self.optimizer)
         return loss
 
     def step_through(self, Doc doc):