Mirror of https://github.com/explosion/spaCy.git
Add multi-task objective for sentence segmentation

parent e7deadb519
commit 12264f9296
@@ -624,11 +624,13 @@ class MultitaskObjective(Tagger):
             self.make_label = self.make_dep_tag_offset
         elif target == 'ent_tag':
             self.make_label = self.make_ent_tag
+        elif target == 'sent_start':
+            self.make_label = self.make_sent_start
         elif hasattr(target, '__call__'):
             self.make_label = target
         else:
             raise ValueError("MultitaskObjective target should be function or "
-                             "one of: dep, tag, ent, dep_tag_offset, ent_tag.")
+                             "one of: dep, tag, ent, sent_start, dep_tag_offset, ent_tag.")
         self.cfg = dict(cfg)
         self.cfg.setdefault('cnn_maxout_pieces', 2)
         self.cfg.setdefault('pretrained_dims',
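The hunk above only registers the new target string; the handler itself is added in the next hunk. As a rough, self-contained sketch of the dispatch pattern being extended (the helper name and the stand-in handler table below are illustrative and not spaCy code; only the accepted target names and the error message mirror the diff):

def resolve_make_label(target, named_handlers):
    '''Pick the label-making callable for a named or user-supplied target.'''
    if target in named_handlers:
        return named_handlers[target]
    elif callable(target):
        # A user-supplied function is accepted as-is.
        return target
    raise ValueError("MultitaskObjective target should be function or "
                     "one of: dep, tag, ent, sent_start, dep_tag_offset, ent_tag.")

# Stand-in handler table; in the real class these are bound methods such as
# make_sent_start.
handlers = {'sent_start': lambda *args, **kwargs: 'I-SENT'}
make_label = resolve_make_label('sent_start', handlers)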
@@ -737,6 +739,52 @@ class MultitaskObjective(Tagger):
         else:
             return '%s-%s' % (tags[i], ents[i])
+
+    @staticmethod
+    def make_sent_start(target, words, tags, heads, deps, ents, cache=True, _cache={}):
+        '''A multi-task objective for representing sentence boundaries,
+        using BILU scheme. (O is impossible)
+
+        The implementation of this method uses an internal cache that relies
+        on the identity of the heads array, to avoid requiring a new piece
+        of gold data. You can pass cache=False if you know the cache will
+        do the wrong thing.
+        '''
+        if cache:
+            if id(heads) in _cache:
+                return _cache[id(heads)][target]
+            else:
+                for key in list(_cache.keys()):
+                    _cache.pop(key)
+            sent_tags = ['I-SENT'] * len(words)
+            _cache[id(heads)] = sent_tags
+        else:
+            sent_tags = ['I-SENT'] * len(words)
+
+        def _find_root(child):
+            while heads[child] != child:
+                if heads[child] is None:
+                    if child == 0:
+                        return child
+                    else:
+                        child -= 1
+                else:
+                    child = heads[child]
+            return child
+
+        sentences = {}
+        for i in range(len(words)):
+            root = _find_root(i)
+            sentences.setdefault(root, []).append(i)
+        for root, span in sorted(sentences.items()):
+            if len(span) == 1:
+                sent_tags[span[0]] = 'U-SENT'
+            else:
+                sent_tags[span[0]] = 'B-SENT'
+                sent_tags[span[-1]] = 'L-SENT'
+        return sent_tags[target]
+
+


 class SimilarityHook(Pipe):
     """
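To make the intended output concrete, here is a small toy check of what make_sent_start should return for a five-token, two-sentence parse. Treat this as a sketch: the import path is an assumption (the class is defined in spaCy's pipeline module at this point in history), and the heads array and expected labels are illustrative only.

from spacy.pipeline import MultitaskObjective   # assumed import path

# Toy parse: tokens 0-2 form a tree rooted at 1, tokens 3-4 a tree rooted at 4.
words = ['I', 'like', 'cats', 'They', 'sleep']
heads = [1, 1, 1, 4, 4]                    # each token's head index; roots point at themselves
tags = deps = ents = [None] * len(words)   # not consulted by make_sent_start

labels = [MultitaskObjective.make_sent_start(i, words, tags, heads, deps, ents,
                                             cache=False)
          for i in range(len(words))]
print(labels)   # expected: ['B-SENT', 'I-SENT', 'L-SENT', 'B-SENT', 'L-SENT']

With the default cache=True, repeated calls for the same heads object reuse the cached tag list keyed on id(heads), which is why the docstring warns that the cache can do the wrong thing if a heads array is reused for different data.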