diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 69e256167..91204f671 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -12,7 +12,7 @@ import srsly from . import _align from .syntax import nonproj -from .tokens import Doc +from .tokens import Doc, Span from .errors import Errors from . import util from .util import minibatch, itershuffle @@ -659,6 +659,24 @@ def biluo_tags_from_offsets(doc, entities, missing='O'): return biluo +def spans_from_biluo_tags(doc, tags): + """Encode per-token tags following the BILUO scheme into Span objects, e.g. + to overwrite the doc.ents. + + doc (Doc): The document that the BILUO tags refer to. + tags (iterable): A sequence of BILUO tags with each tag describing one + token. Each tag string will be of the form of either "", "O" or + "{action}-{label}", where action is one of "B", "I", "L", "U". + RETURNS (list): A sequence of Span objects. + """ + token_offsets = tags_to_entities(tags) + spans = [] + for label, start_idx, end_idx in token_offsets: + span = Span(doc, start_idx, end_idx + 1, label=label) + spans.append(span) + return spans + + def offsets_from_biluo_tags(doc, tags): """Encode per-token tags following the BILUO scheme into entity offsets. @@ -670,12 +688,8 @@ def offsets_from_biluo_tags(doc, tags): `end` will be character-offset integers denoting the slice into the original string. 
""" - token_offsets = tags_to_entities(tags) - offsets = [] - for label, start_idx, end_idx in token_offsets: - span = doc[start_idx : end_idx + 1] - offsets.append((span.start_char, span.end_char, label)) - return offsets + spans = spans_from_biluo_tags(doc, tags) + return [(span.start_char, span.end_char, span.label_) for span in spans] def is_punct_label(label): diff --git a/spacy/tests/test_gold.py b/spacy/tests/test_gold.py index 7c230f469..30dd2e6c6 100644 --- a/spacy/tests/test_gold.py +++ b/spacy/tests/test_gold.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags +from spacy.gold import spans_from_biluo_tags from spacy.tokens import Doc @@ -50,3 +51,14 @@ def test_roundtrip_offsets_biluo_conversion(en_tokenizer): assert biluo_tags_converted == biluo_tags offsets_converted = offsets_from_biluo_tags(doc, biluo_tags) assert offsets_converted == offsets + + +def test_biluo_spans(en_tokenizer): + doc = en_tokenizer("I flew to Silicon Valley via London.") + biluo_tags = ["O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"] + spans = spans_from_biluo_tags(doc, biluo_tags) + assert len(spans) == 2 + assert spans[0].text == "Silicon Valley" + assert spans[0].label_ == "LOC" + assert spans[1].text == "London" + assert spans[1].label_ == "GPE"