From c4e2daf6ef24a280e3c252e62a8534110c417ce8 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 2 Nov 2023 12:02:18 +0100 Subject: [PATCH] Fix displacy span stacking (#13068) * Fix displacy span stacking. * Format. Remove counter. * Remove test files. * Add unit test. Refactor to allow for unit test. * Fix off-by-one error in tests. --- spacy/displacy/render.py | 39 +++++++++++++++++++++++++++--------- spacy/tests/test_displacy.py | 22 +++++++++++++++++++- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index 2ab41ccc2..40b9986e8 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -142,7 +142,25 @@ class SpanRenderer: spans (list): Individual entity spans and their start, end, label, kb_id and kb_url. title (str / None): Document title set in Doc.user_data['title']. """ - per_token_info = [] + per_token_info = self._assemble_per_token_info(tokens, spans) + markup = self._render_markup(per_token_info) + markup = TPL_SPANS.format(content=markup, dir=self.direction) + if title: + markup = TPL_TITLE.format(title=title) + markup + return markup + + @staticmethod + def _assemble_per_token_info( + tokens: List[str], spans: List[Dict[str, Any]] + ) -> List[Dict[str, List[Dict[str, Any]]]]: + """Assembles token info used to generate markup in render_spans(). + tokens (List[str]): Tokens in text. + spans (List[Dict[str, Any]]): Spans in text. + RETURNS (List[Dict[str, List[Dict[str, Any]]]]): Per token info needed to render HTML markup for given tokens + and spans. 
+ """ + per_token_info: List[Dict[str, List[Dict[str, Any]]]] = [] + # we must sort so that we can correctly describe when spans need to "stack" # which is determined by their start token, then span length (longer spans on top), # then break any remaining ties with the span label @@ -154,21 +172,22 @@ class SpanRenderer: s["label"], ), ) + for s in spans: # this is the vertical 'slot' that the span will be rendered in # vertical_position = span_label_offset + (offset_step * (slot - 1)) s["render_slot"] = 0 + for idx, token in enumerate(tokens): # Identify if a token belongs to a Span (and which) and if it's a # start token of said Span. We'll use this for the final HTML render token_markup: Dict[str, Any] = {} token_markup["text"] = token - concurrent_spans = 0 + intersecting_spans: List[Dict[str, Any]] = [] entities = [] for span in spans: ent = {} if span["start_token"] <= idx < span["end_token"]: - concurrent_spans += 1 span_start = idx == span["start_token"] ent["label"] = span["label"] ent["is_start"] = span_start @@ -176,7 +195,12 @@ class SpanRenderer: # When the span starts, we need to know how many other # spans are on the 'span stack' and will be rendered. 
# This value becomes the vertical render slot for this entire span - span["render_slot"] = concurrent_spans + span["render_slot"] = ( + intersecting_spans[-1]["render_slot"] + if len(intersecting_spans) + else 0 + ) + 1 + intersecting_spans.append(span) ent["render_slot"] = span["render_slot"] kb_id = span.get("kb_id", "") kb_url = span.get("kb_url", "#") @@ -193,11 +217,8 @@ class SpanRenderer: span["render_slot"] = 0 token_markup["entities"] = entities per_token_info.append(token_markup) - markup = self._render_markup(per_token_info) - markup = TPL_SPANS.format(content=markup, dir=self.direction) - if title: - markup = TPL_TITLE.format(title=title) + markup - return markup + + return per_token_info def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str: """Render the markup from per-token information""" diff --git a/spacy/tests/test_displacy.py b/spacy/tests/test_displacy.py index 12d903dca..b83c7db07 100644 --- a/spacy/tests/test_displacy.py +++ b/spacy/tests/test_displacy.py @@ -2,7 +2,7 @@ import numpy import pytest from spacy import displacy -from spacy.displacy.render import DependencyRenderer, EntityRenderer +from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer from spacy.lang.en import English from spacy.lang.fa import Persian from spacy.tokens import Doc, Span @@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None: # Verify that the HTML tag is still escaped html = displacy.render(doc, style="span") assert "&lt;TEST&gt;" in html + + +@pytest.mark.issue(13056) +def test_displacy_span_stacking(): + """Test whether span stacking works properly for multiple overlapping spans.""" + spans = [ + {"start_token": 2, "end_token": 5, "label": "SkillNC"}, + {"start_token": 0, "end_token": 2, "label": "Skill"}, + {"start_token": 1, "end_token": 3, "label": "Skill"}, + ] + tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."] + per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens) + + 
assert len(per_token_info) == len(tokens) + assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)]) + assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)]) + assert per_token_info[1]["entities"][0]["render_slot"] == 1 + assert per_token_info[1]["entities"][1]["render_slot"] == 2 + assert per_token_info[2]["entities"][0]["render_slot"] == 2 + assert per_token_info[2]["entities"][1]["render_slot"] == 3