Add gold.spans_from_biluo_tags helper (#3227)

This commit is contained in:
Ines Montani 2019-02-06 11:50:26 +01:00 committed by Matthew Honnibal
parent 5e16490d9d
commit f25bd9f5e4
2 changed files with 33 additions and 7 deletions

View File

@ -12,7 +12,7 @@ import srsly
from . import _align
from .syntax import nonproj
from .tokens import Doc
from .tokens import Doc, Span
from .errors import Errors
from . import util
from .util import minibatch, itershuffle
@ -659,6 +659,24 @@ def biluo_tags_from_offsets(doc, entities, missing='O'):
return biluo
def spans_from_biluo_tags(doc, tags):
"""Encode per-token tags following the BILUO scheme into Span object, e.g.
to overwrite the doc.ents.
doc (Doc): The document that the BILUO tags refer to.
entities (iterable): A sequence of BILUO tags with each tag describing one
token. Each tags string will be of the form of either "", "O" or
"{action}-{label}", where action is one of "B", "I", "L", "U".
RETURNS (list): A sequence of Span objects.
"""
token_offsets = tags_to_entities(tags)
spans = []
for label, start_idx, end_idx in token_offsets:
span = Span(doc, start_idx, end_idx + 1, label=label)
spans.append(span)
return spans
def offsets_from_biluo_tags(doc, tags):
"""Encode per-token tags following the BILUO scheme into entity offsets.
@ -670,12 +688,8 @@ def offsets_from_biluo_tags(doc, tags):
`end` will be character-offset integers denoting the slice into the
original string.
"""
token_offsets = tags_to_entities(tags)
offsets = []
for label, start_idx, end_idx in token_offsets:
span = doc[start_idx : end_idx + 1]
offsets.append((span.start_char, span.end_char, label))
return offsets
spans = spans_from_biluo_tags(doc, tags)
return [(span.start_char, span.end_char, span.label_) for span in spans]
def is_punct_label(label):

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags
from spacy.tokens import Doc
@ -50,3 +51,14 @@ def test_roundtrip_offsets_biluo_conversion(en_tokenizer):
assert biluo_tags_converted == biluo_tags
offsets_converted = offsets_from_biluo_tags(doc, biluo_tags)
assert offsets_converted == offsets
def test_biluo_spans(en_tokenizer):
doc = en_tokenizer("I flew to Silicon Valley via London.")
biluo_tags = ["O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
spans = spans_from_biluo_tags(doc, biluo_tags)
assert len(spans) == 2
assert spans[0].text == "Silicon Valley"
assert spans[0].label_ == "LOC"
assert spans[1].text == "London"
assert spans[1].label_ == "GPE"