mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add labels implicitly for parser and NER
This commit is contained in:
		
							parent
							
								
									68b1c2984d
								
							
						
					
					
						commit
						1d20e21f3e
					
				| 
						 | 
					@ -614,10 +614,22 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        actions[LEFT].setdefault('dep', 0)
 | 
					        actions[LEFT].setdefault('dep', 0)
 | 
				
			||||||
        return actions
 | 
					        return actions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def builtin_labels(self):
 | 
				
			||||||
 | 
					        return ["ROOT", "dep"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def action_types(self):
 | 
					    def action_types(self):
 | 
				
			||||||
        return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
 | 
					        return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_doc_labels(self, doc):
 | 
				
			||||||
 | 
					        """Get the labels required for a given Doc."""
 | 
				
			||||||
 | 
					        labels = set(self.builtin_labels)
 | 
				
			||||||
 | 
					        for token in doc:
 | 
				
			||||||
 | 
					            if token.dep_:
 | 
				
			||||||
 | 
					                labels.add(token.dep_)
 | 
				
			||||||
 | 
					        return labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def transition(self, StateClass state, action):
 | 
					    def transition(self, StateClass state, action):
 | 
				
			||||||
        cdef Transition t = self.lookup_transition(action)
 | 
					        cdef Transition t = self.lookup_transition(action)
 | 
				
			||||||
        t.do(state.c, t.label)
 | 
					        t.do(state.c, t.label)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -126,6 +126,13 @@ cdef class BiluoPushDown(TransitionSystem):
 | 
				
			||||||
    def action_types(self):
 | 
					    def action_types(self):
 | 
				
			||||||
        return (BEGIN, IN, LAST, UNIT, OUT)
 | 
					        return (BEGIN, IN, LAST, UNIT, OUT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_doc_labels(self, doc):
 | 
				
			||||||
 | 
					        labels = set()
 | 
				
			||||||
 | 
					        for token in doc:
 | 
				
			||||||
 | 
					            if token.ent_type:
 | 
				
			||||||
 | 
					                labels.add(token.ent_type_)
 | 
				
			||||||
 | 
					        return labels
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    def move_name(self, int move, attr_t label):
 | 
					    def move_name(self, int move, attr_t label):
 | 
				
			||||||
        if move == OUT:
 | 
					        if move == OUT:
 | 
				
			||||||
            return 'O'
 | 
					            return 'O'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -132,6 +132,23 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            return 1
 | 
					            return 1
 | 
				
			||||||
        return 0
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _ensure_labels_are_added(self, docs):
 | 
				
			||||||
 | 
					        """Ensure that all labels for a batch of docs are added."""
 | 
				
			||||||
 | 
					        resized = False
 | 
				
			||||||
 | 
					        labels = set()
 | 
				
			||||||
 | 
					        for doc in docs:
 | 
				
			||||||
 | 
					            labels.update(self.moves.get_doc_labels(doc))
 | 
				
			||||||
 | 
					        for label in labels:
 | 
				
			||||||
 | 
					            for action in self.moves.action_types:
 | 
				
			||||||
 | 
					                added = self.moves.add_action(action, label)
 | 
				
			||||||
 | 
					                if added:
 | 
				
			||||||
 | 
					                    self.vocab.strings.add(label)
 | 
				
			||||||
 | 
					                    resized = True
 | 
				
			||||||
 | 
					        if resized:
 | 
				
			||||||
 | 
					            self._resize()
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _resize(self):
 | 
					    def _resize(self):
 | 
				
			||||||
        self.model.attrs["resize_output"](self.model, self.moves.n_moves)
 | 
					        self.model.attrs["resize_output"](self.model, self.moves.n_moves)
 | 
				
			||||||
        if self._rehearsal_model not in (True, False, None):
 | 
					        if self._rehearsal_model not in (True, False, None):
 | 
				
			||||||
| 
						 | 
					@ -188,9 +205,9 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
    def predict(self, docs):
 | 
					    def predict(self, docs):
 | 
				
			||||||
        if isinstance(docs, Doc):
 | 
					        if isinstance(docs, Doc):
 | 
				
			||||||
            docs = [docs]
 | 
					            docs = [docs]
 | 
				
			||||||
 | 
					        self._ensure_labels_are_added(docs)
 | 
				
			||||||
        if not any(len(doc) for doc in docs):
 | 
					        if not any(len(doc) for doc in docs):
 | 
				
			||||||
            result = self.moves.init_batch(docs)
 | 
					            result = self.moves.init_batch(docs)
 | 
				
			||||||
            self._resize()
 | 
					 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
        if self.cfg["beam_width"] == 1:
 | 
					        if self.cfg["beam_width"] == 1:
 | 
				
			||||||
            return self.greedy_parse(docs, drop=0.0)
 | 
					            return self.greedy_parse(docs, drop=0.0)
 | 
				
			||||||
| 
						 | 
					@ -207,10 +224,6 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        cdef StateClass state
 | 
					        cdef StateClass state
 | 
				
			||||||
        set_dropout_rate(self.model, drop)
 | 
					        set_dropout_rate(self.model, drop)
 | 
				
			||||||
        batch = self.moves.init_batch(docs)
 | 
					        batch = self.moves.init_batch(docs)
 | 
				
			||||||
        # This is pretty dirty, but the NER can resize itself in init_batch,
 | 
					 | 
				
			||||||
        # if labels are missing. We therefore have to check whether we need to
 | 
					 | 
				
			||||||
        # expand our model output.
 | 
					 | 
				
			||||||
        self._resize()
 | 
					 | 
				
			||||||
        model = self.model.predict(docs)
 | 
					        model = self.model.predict(docs)
 | 
				
			||||||
        weights = get_c_weights(model)
 | 
					        weights = get_c_weights(model)
 | 
				
			||||||
        for state in batch:
 | 
					        for state in batch:
 | 
				
			||||||
| 
						 | 
					@ -234,10 +247,6 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            beam_width,
 | 
					            beam_width,
 | 
				
			||||||
            density=beam_density
 | 
					            density=beam_density
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        # This is pretty dirty, but the NER can resize itself in init_batch,
 | 
					 | 
				
			||||||
        # if labels are missing. We therefore have to check whether we need to
 | 
					 | 
				
			||||||
        # expand our model output.
 | 
					 | 
				
			||||||
        self._resize()
 | 
					 | 
				
			||||||
        model = self.model.predict(docs)
 | 
					        model = self.model.predict(docs)
 | 
				
			||||||
        while not batch.is_done:
 | 
					        while not batch.is_done:
 | 
				
			||||||
            states = batch.get_unfinished_states()
 | 
					            states = batch.get_unfinished_states()
 | 
				
			||||||
| 
						 | 
					@ -314,6 +323,9 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            losses = {}
 | 
					            losses = {}
 | 
				
			||||||
        losses.setdefault(self.name, 0.)
 | 
					        losses.setdefault(self.name, 0.)
 | 
				
			||||||
        validate_examples(examples, "Parser.update")
 | 
					        validate_examples(examples, "Parser.update")
 | 
				
			||||||
 | 
					        self._ensure_labels_are_added(
 | 
				
			||||||
 | 
					            [eg.x for eg in examples] + [eg.y for eg in examples]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        for multitask in self._multitasks:
 | 
					        for multitask in self._multitasks:
 | 
				
			||||||
            multitask.update(examples, drop=drop, sgd=sgd)
 | 
					            multitask.update(examples, drop=drop, sgd=sgd)
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user