mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Allow multitask objectives to be added to the parser and NER more easily
This commit is contained in:
		
							parent
							
								
									4a7d524efb
								
							
						
					
					
						commit
						203d2ea830
					
				| 
						 | 
				
			
			@ -882,14 +882,16 @@ cdef class DependencyParser(Parser):
 | 
			
		|||
    def postprocesses(self):
 | 
			
		||||
        return [nonproj.deprojectivize]
 | 
			
		||||
    
 | 
			
		||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
			
		||||
        for target in []:
 | 
			
		||||
    def add_multitask_objective(self, target):
 | 
			
		||||
        labeller = MultitaskObjective(self.vocab, target=target)
 | 
			
		||||
        self._multitasks.append(labeller)
 | 
			
		||||
 | 
			
		||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
			
		||||
        for labeller in self._multitasks:
 | 
			
		||||
            tok2vec = self.model[0]
 | 
			
		||||
            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
			
		||||
                                    tok2vec=tok2vec, sgd=sgd)
 | 
			
		||||
            pipeline.append(labeller)
 | 
			
		||||
            self._multitasks.append(labeller)
 | 
			
		||||
            pipeline.append((labeller.name, labeller))
 | 
			
		||||
 | 
			
		||||
    def __reduce__(self):
 | 
			
		||||
        return (DependencyParser, (self.vocab, self.moves, self.model),
 | 
			
		||||
| 
						 | 
				
			
			@ -902,14 +904,16 @@ cdef class EntityRecognizer(Parser):
 | 
			
		|||
 | 
			
		||||
    nr_feature = 6
 | 
			
		||||
    
 | 
			
		||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
			
		||||
        for target in []:
 | 
			
		||||
    def add_multitask_objective(self, target):
 | 
			
		||||
        labeller = MultitaskObjective(self.vocab, target=target)
 | 
			
		||||
        self._multitasks.append(labeller)
 | 
			
		||||
 | 
			
		||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
			
		||||
        for labeller in self._multitasks:
 | 
			
		||||
            tok2vec = self.model[0]
 | 
			
		||||
            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
			
		||||
                                    tok2vec=tok2vec)
 | 
			
		||||
            pipeline.append(labeller)
 | 
			
		||||
            self._multitasks.append(labeller)
 | 
			
		||||
            pipeline.append((labeller.name, labeller))
 | 
			
		||||
 | 
			
		||||
    def __reduce__(self):
 | 
			
		||||
        return (EntityRecognizer, (self.vocab, self.moves, self.model),
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -269,9 +269,6 @@ cdef class Parser:
 | 
			
		|||
                zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # TODO: This is an unfortunate hack atm!
 | 
			
		||||
        # Used to set input dimensions in network.
 | 
			
		||||
        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
 | 
			
		||||
        cfg = {
 | 
			
		||||
            'nr_class': nr_class,
 | 
			
		||||
            'hidden_depth': depth,
 | 
			
		||||
| 
						 | 
				
			
			@ -840,8 +837,14 @@ cdef class Parser:
 | 
			
		|||
            self.cfg.update(cfg)
 | 
			
		||||
        elif sgd is None:
 | 
			
		||||
            sgd = self.create_optimizer()
 | 
			
		||||
        self.model[1].begin_training(
 | 
			
		||||
            self.model[1].ops.allocate((5, cfg['token_vector_width'])))
 | 
			
		||||
        return sgd
 | 
			
		||||
 | 
			
		||||
    def add_multitask_objective(self, target):
 | 
			
		||||
        # Defined in subclasses, to avoid circular import
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
    
 | 
			
		||||
    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
 | 
			
		||||
        '''Setup models for secondary objectives, to benefit from multi-task
 | 
			
		||||
        learning. This method is intended to be overridden by subclasses.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user