Fix displacy span stacking (#13068)

* Fix displacy span stacking.

* Format. Remove counter.

* Remove test files.

* Add unit test. Refactor to allow for unit test.

* Fix off-by-one error in tests.
This commit is contained in:
Raphael Mitsch 2023-11-02 12:02:18 +01:00 committed by GitHub
parent a804b83a4b
commit c4e2daf6ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 10 deletions

View File

@ -142,7 +142,25 @@ class SpanRenderer:
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url. spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
title (str / None): Document title set in Doc.user_data['title']. title (str / None): Document title set in Doc.user_data['title'].
""" """
per_token_info = [] per_token_info = self._assemble_per_token_info(tokens, spans)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup
@staticmethod
def _assemble_per_token_info(
tokens: List[str], spans: List[Dict[str, Any]]
) -> List[Dict[str, List[Dict[str, Any]]]]:
"""Assembles token info used to generate markup in render_spans().
tokens (List[str]): Tokens in text.
spans (List[Dict[str, Any]]): Spans in text.
RETURNS (List[Dict[str, List[Dict, str, Any]]]): Per token info needed to render HTML markup for given tokens
and spans.
"""
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []
# we must sort so that we can correctly describe when spans need to "stack" # we must sort so that we can correctly describe when spans need to "stack"
# which is determined by their start token, then span length (longer spans on top), # which is determined by their start token, then span length (longer spans on top),
# then break any remaining ties with the span label # then break any remaining ties with the span label
@ -154,21 +172,22 @@ class SpanRenderer:
s["label"], s["label"],
), ),
) )
for s in spans: for s in spans:
# this is the vertical 'slot' that the span will be rendered in # this is the vertical 'slot' that the span will be rendered in
# vertical_position = span_label_offset + (offset_step * (slot - 1)) # vertical_position = span_label_offset + (offset_step * (slot - 1))
s["render_slot"] = 0 s["render_slot"] = 0
for idx, token in enumerate(tokens): for idx, token in enumerate(tokens):
# Identify if a token belongs to a Span (and which) and if it's a # Identify if a token belongs to a Span (and which) and if it's a
# start token of said Span. We'll use this for the final HTML render # start token of said Span. We'll use this for the final HTML render
token_markup: Dict[str, Any] = {} token_markup: Dict[str, Any] = {}
token_markup["text"] = token token_markup["text"] = token
concurrent_spans = 0 intersecting_spans: List[Dict[str, Any]] = []
entities = [] entities = []
for span in spans: for span in spans:
ent = {} ent = {}
if span["start_token"] <= idx < span["end_token"]: if span["start_token"] <= idx < span["end_token"]:
concurrent_spans += 1
span_start = idx == span["start_token"] span_start = idx == span["start_token"]
ent["label"] = span["label"] ent["label"] = span["label"]
ent["is_start"] = span_start ent["is_start"] = span_start
@ -176,7 +195,12 @@ class SpanRenderer:
# When the span starts, we need to know how many other # When the span starts, we need to know how many other
# spans are on the 'span stack' and will be rendered. # spans are on the 'span stack' and will be rendered.
# This value becomes the vertical render slot for this entire span # This value becomes the vertical render slot for this entire span
span["render_slot"] = concurrent_spans span["render_slot"] = (
intersecting_spans[-1]["render_slot"]
if len(intersecting_spans)
else 0
) + 1
intersecting_spans.append(span)
ent["render_slot"] = span["render_slot"] ent["render_slot"] = span["render_slot"]
kb_id = span.get("kb_id", "") kb_id = span.get("kb_id", "")
kb_url = span.get("kb_url", "#") kb_url = span.get("kb_url", "#")
@ -193,11 +217,8 @@ class SpanRenderer:
span["render_slot"] = 0 span["render_slot"] = 0
token_markup["entities"] = entities token_markup["entities"] = entities
per_token_info.append(token_markup) per_token_info.append(token_markup)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction) return per_token_info
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup
def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str: def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
"""Render the markup from per-token information""" """Render the markup from per-token information"""

View File

@ -2,7 +2,7 @@ import numpy
import pytest import pytest
from spacy import displacy from spacy import displacy
from spacy.displacy.render import DependencyRenderer, EntityRenderer from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.fa import Persian from spacy.lang.fa import Persian
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
# Verify that the HTML tag is still escaped # Verify that the HTML tag is still escaped
html = displacy.render(doc, style="span") html = displacy.render(doc, style="span")
assert "&lt;TEST&gt;" in html assert "&lt;TEST&gt;" in html
@pytest.mark.issue(13056)
def test_displacy_span_stacking():
"""Test whether span stacking works properly for multiple overlapping spans."""
spans = [
{"start_token": 2, "end_token": 5, "label": "SkillNC"},
{"start_token": 0, "end_token": 2, "label": "Skill"},
{"start_token": 1, "end_token": 3, "label": "Skill"},
]
tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."]
per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens)
assert len(per_token_info) == len(tokens)
assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)])
assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)])
assert per_token_info[1]["entities"][0]["render_slot"] == 1
assert per_token_info[1]["entities"][1]["render_slot"] == 2
assert per_token_info[2]["entities"][0]["render_slot"] == 2
assert per_token_info[2]["entities"][1]["render_slot"] == 3