Add multi-task objective for sentence segmentation

This commit is contained in:
Matthew Honnibal 2018-02-23 16:25:57 +01:00
parent e7deadb519
commit 12264f9296

View File

@ -624,11 +624,13 @@ class MultitaskObjective(Tagger):
self.make_label = self.make_dep_tag_offset self.make_label = self.make_dep_tag_offset
elif target == 'ent_tag': elif target == 'ent_tag':
self.make_label = self.make_ent_tag self.make_label = self.make_ent_tag
elif target == 'sent_start':
self.make_label = self.make_sent_start
elif hasattr(target, '__call__'): elif hasattr(target, '__call__'):
self.make_label = target self.make_label = target
else: else:
raise ValueError("MultitaskObjective target should be function or " raise ValueError("MultitaskObjective target should be function or "
"one of: dep, tag, ent, dep_tag_offset, ent_tag.") "one of: dep, tag, ent, sent_start, dep_tag_offset, ent_tag.")
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.cfg.setdefault('cnn_maxout_pieces', 2) self.cfg.setdefault('cnn_maxout_pieces', 2)
self.cfg.setdefault('pretrained_dims', self.cfg.setdefault('pretrained_dims',
@ -737,6 +739,52 @@ class MultitaskObjective(Tagger):
else: else:
return '%s-%s' % (tags[i], ents[i]) return '%s-%s' % (tags[i], ents[i])
@staticmethod
def make_sent_start(target, words, tags, heads, deps, ents, cache=True, _cache={}):
'''A multi-task objective for representing sentence boundaries,
using BILU scheme. (O is impossible)
The implementation of this method uses an internal cache that relies
on the identity of the heads array, to avoid requiring a new piece
of gold data. You can pass cache=False if you know the cache will
do the wrong thing.
'''
if cache:
if id(heads) in _cache:
return _cache[id(heads)][target]
else:
for key in list(_cache.keys()):
_cache.pop(key)
sent_tags = ['I-SENT'] * len(words)
_cache[id(heads)] = sent_tags
else:
sent_tags = ['I-SENT'] * len(words)
def _find_root(child):
while heads[child] != child:
if heads[child] is None:
if child == 0:
return child
else:
child -= 1
else:
child = heads[child]
return child
sentences = {}
for i in range(len(words)):
root = _find_root(i)
sentences.setdefault(root, []).append(i)
for root, span in sorted(sentences.items()):
if len(span) == 1:
sent_tags[span[0]] = 'U-SENT'
else:
sent_tags[span[0]] = 'B-SENT'
sent_tags[span[-1]] = 'L-SENT'
return sent_tags[target]
class SimilarityHook(Pipe): class SimilarityHook(Pipe):
""" """