mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Allow multitask objectives to be added to the parser and NER more easily
This commit is contained in:
		
							parent
							
								
									4a7d524efb
								
							
						
					
					
						commit
						203d2ea830
					
				| 
						 | 
					@ -882,14 +882,16 @@ cdef class DependencyParser(Parser):
 | 
				
			||||||
    def postprocesses(self):
 | 
					    def postprocesses(self):
 | 
				
			||||||
        return [nonproj.deprojectivize]
 | 
					        return [nonproj.deprojectivize]
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
					    def add_multitask_objective(self, target):
 | 
				
			||||||
        for target in []:
 | 
					 | 
				
			||||||
        labeller = MultitaskObjective(self.vocab, target=target)
 | 
					        labeller = MultitaskObjective(self.vocab, target=target)
 | 
				
			||||||
 | 
					        self._multitasks.append(labeller)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
				
			||||||
 | 
					        for labeller in self._multitasks:
 | 
				
			||||||
            tok2vec = self.model[0]
 | 
					            tok2vec = self.model[0]
 | 
				
			||||||
            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
					            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
				
			||||||
                                    tok2vec=tok2vec, sgd=sgd)
 | 
					                                    tok2vec=tok2vec, sgd=sgd)
 | 
				
			||||||
            pipeline.append(labeller)
 | 
					            pipeline.append((labeller.name, labeller))
 | 
				
			||||||
            self._multitasks.append(labeller)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __reduce__(self):
 | 
					    def __reduce__(self):
 | 
				
			||||||
        return (DependencyParser, (self.vocab, self.moves, self.model),
 | 
					        return (DependencyParser, (self.vocab, self.moves, self.model),
 | 
				
			||||||
| 
						 | 
					@ -902,14 +904,16 @@ cdef class EntityRecognizer(Parser):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    nr_feature = 6
 | 
					    nr_feature = 6
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
					    def add_multitask_objective(self, target):
 | 
				
			||||||
        for target in []:
 | 
					 | 
				
			||||||
        labeller = MultitaskObjective(self.vocab, target=target)
 | 
					        labeller = MultitaskObjective(self.vocab, target=target)
 | 
				
			||||||
 | 
					        self._multitasks.append(labeller)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
 | 
				
			||||||
 | 
					        for labeller in self._multitasks:
 | 
				
			||||||
            tok2vec = self.model[0]
 | 
					            tok2vec = self.model[0]
 | 
				
			||||||
            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
					            labeller.begin_training(gold_tuples, pipeline=pipeline,
 | 
				
			||||||
                                    tok2vec=tok2vec)
 | 
					                                    tok2vec=tok2vec)
 | 
				
			||||||
            pipeline.append(labeller)
 | 
					            pipeline.append((labeller.name, labeller))
 | 
				
			||||||
            self._multitasks.append(labeller)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __reduce__(self):
 | 
					    def __reduce__(self):
 | 
				
			||||||
        return (EntityRecognizer, (self.vocab, self.moves, self.model),
 | 
					        return (EntityRecognizer, (self.vocab, self.moves, self.model),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -269,9 +269,6 @@ cdef class Parser:
 | 
				
			||||||
                zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
 | 
					                zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: This is an unfortunate hack atm!
 | 
					 | 
				
			||||||
        # Used to set input dimensions in network.
 | 
					 | 
				
			||||||
        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
 | 
					 | 
				
			||||||
        cfg = {
 | 
					        cfg = {
 | 
				
			||||||
            'nr_class': nr_class,
 | 
					            'nr_class': nr_class,
 | 
				
			||||||
            'hidden_depth': depth,
 | 
					            'hidden_depth': depth,
 | 
				
			||||||
| 
						 | 
					@ -840,8 +837,14 @@ cdef class Parser:
 | 
				
			||||||
            self.cfg.update(cfg)
 | 
					            self.cfg.update(cfg)
 | 
				
			||||||
        elif sgd is None:
 | 
					        elif sgd is None:
 | 
				
			||||||
            sgd = self.create_optimizer()
 | 
					            sgd = self.create_optimizer()
 | 
				
			||||||
 | 
					        self.model[1].begin_training(
 | 
				
			||||||
 | 
					            self.model[1].ops.allocate((5, cfg['token_vector_width'])))
 | 
				
			||||||
        return sgd
 | 
					        return sgd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_multitask_objective(self, target):
 | 
				
			||||||
 | 
					        # Defined in subclasses, to avoid circular import
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
 | 
					    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
 | 
				
			||||||
        '''Setup models for secondary objectives, to benefit from multi-task
 | 
					        '''Setup models for secondary objectives, to benefit from multi-task
 | 
				
			||||||
        learning. This method is intended to be overridden by subclasses.
 | 
					        learning. This method is intended to be overridden by subclasses.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user