Add function for entity->biluo transformation

This commit is contained in:
Matthew Honnibal 2016-10-15 21:51:04 +02:00
parent 2163fd238f
commit 86ae665c78

View File

@ -267,6 +267,69 @@ cdef class GoldParse:
return not nonproj.is_nonproj_tree(self.heads)
def biluo_tags_from_offsets(doc, entities):
'''Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out
scheme (biluo).
Arguments:
doc (Doc):
The document that the entity offsets refer to. The output tags will
refer to the token boundaries within the document.
entities (sequence):
A sequence of (start, end, label) triples. start and end should be
character-offset integers denoting the slice into the original string.
Returns:
tags (list):
A list of unicode strings, describing the tags. Each tag string will
be of the form either "", "O" or "{action}-{label}", where action is one
of "B", "I", "L", "U". The empty string "" is used where the entity
offsets don't align with the tokenization in the Doc object. The
training algorithm will view these as missing values. "O" denotes
a non-entity token. "B" denotes the beginning of a multi-token entity,
"I" the inside of an entity of three or more tokens, and "L" the end
of an entity of two or more tokens. "U" denotes a single-token entity.
Example:
text = 'I like London.'
entities = [(len('I like '), len('I like London'), 'LOC')]
doc = nlp.tokenizer(text)
tags = biluo_tags_from_offsets(doc, entities)
assert tags == ['O', 'O', 'U-LOC', 'O']
'''
starts = {token.idx: token.i for token in doc}
ends = {token.idx+len(token): token.i for token in doc}
biluo = ['' for _ in doc]
# Handle entity cases
for start_char, end_char, label in entities:
start_token = starts.get(start_char)
end_token = ends.get(end_char)
# Only interested if the tokenization is correct
if start_token is not None and end_token is not None:
if start_token == end_token:
biluo[start_token] = 'U-%s' % label
else:
biluo[start_token] = 'B-%s' % label
for i in range(start_token+1, end_token):
biluo[i] = 'I-%s' % label
biluo[end_token] = 'L-%s' % label
# Now distinguish the O cases from ones where we miss the tokenization
entity_chars = set()
for start_char, end_char, label in entities:
for i in range(start_char, end_char):
entity_chars.add(i)
for token in doc:
for i in range(token.idx, token.idx+len(token)):
if i in entity_chars:
break
else:
biluo[token.i] = 'O'
return biluo
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'