From c699aec089b589d14ae5c74a6f5bcc1d6db109eb Mon Sep 17 00:00:00 2001 From: ines Date: Sun, 26 Nov 2017 16:38:01 +0100 Subject: [PATCH] Add offsets_from_biluo_tags helper and tests (see #1626) --- spacy/gold.pyx | 19 +++++++++++++++++++ spacy/tests/gold/test_biluo.py | 13 ++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index d6db9b853..dff5fc147 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -541,5 +541,24 @@ def biluo_tags_from_offsets(doc, entities, missing='O'): return biluo +def offsets_from_biluo_tags(doc, tags): + """Encode per-token tags following the BILUO scheme into entity offsets. + + doc (Doc): The document that the BILUO tags refer to. + entities (iterable): A sequence of BILUO tags with each tag describing one + token. Each tags string will be of the form of either "", "O" or + "{action}-{label}", where action is one of "B", "I", "L", "U". + RETURNS (list): A sequence of `(start, end, label)` triples. `start` and + `end` will be character-offset integers denoting the slice into the + original string. + """ + token_offsets = tags_to_entities(tags) + offsets = [] + for label, start_idx, end_idx in token_offsets: + span = doc[start_idx : end_idx + 1] + offsets.append((span.start_char, span.end_char, label)) + return offsets + + def is_punct_label(label): return label == 'P' or label.lower() == 'punct' diff --git a/spacy/tests/gold/test_biluo.py b/spacy/tests/gold/test_biluo.py index a1aa91cf0..b89dd46b8 100644 --- a/spacy/tests/gold/test_biluo.py +++ b/spacy/tests/gold/test_biluo.py @@ -1,7 +1,7 @@ # coding: utf-8 from __future__ import unicode_literals -from ...gold import biluo_tags_from_offsets +from ...gold import biluo_tags_from_offsets, offsets_from_biluo_tags from ...tokens.doc import Doc import pytest @@ -41,3 +41,14 @@ def test_gold_biluo_misalign(en_vocab): entities = [(len("I flew to "), len("I flew to San Francisco Valley"), 'LOC')] tags = biluo_tags_from_offsets(doc, entities) assert tags == ['O', 'O', 'O', '-', '-', '-'] + + +def test_roundtrip_offsets_biluo_conversion(en_tokenizer): + text = "I flew to Silicon Valley via London." + biluo_tags = ['O', 'O', 'O', 'B-LOC', 'L-LOC', 'O', 'U-GPE', 'O'] + offsets = [(10, 24, 'LOC'), (29, 35, 'GPE')] + doc = en_tokenizer(text) + biluo_tags_converted = biluo_tags_from_offsets(doc, offsets) + assert biluo_tags_converted == biluo_tags + offsets_converted = offsets_from_biluo_tags(doc, biluo_tags) + assert offsets_converted == offsets