Add unit test. Refactor to allow for unit test.

This commit is contained in:
Raphael Mitsch 2023-10-18 14:35:27 +02:00
parent eb46a9c62d
commit ce814e9e7d
2 changed files with 44 additions and 7 deletions

View File

@ -142,7 +142,25 @@ class SpanRenderer:
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url. spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
title (str / None): Document title set in Doc.user_data['title']. title (str / None): Document title set in Doc.user_data['title'].
""" """
per_token_info = [] per_token_info = self._assemble_per_token_info(tokens, spans)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup
@staticmethod
def _assemble_per_token_info(
tokens: List[str], spans: List[Dict[str, Any]]
) -> List[Dict[str, List[Dict[str, Any]]]]:
"""Assembles token info used to generate markup in render_spans().
tokens (List[str]): Tokens in text.
spans (List[Dict[str, Any]]): Spans in text.
RETURNS (List[Dict[str, List[Dict, str, Any]]]): Per token info needed to render HTML markup for given tokens
and spans.
"""
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []
# we must sort so that we can correctly describe when spans need to "stack" # we must sort so that we can correctly describe when spans need to "stack"
# which is determined by their start token, then span length (longer spans on top), # which is determined by their start token, then span length (longer spans on top),
# then break any remaining ties with the span label # then break any remaining ties with the span label
@ -154,10 +172,12 @@ class SpanRenderer:
s["label"], s["label"],
), ),
) )
for s in spans: for s in spans:
# this is the vertical 'slot' that the span will be rendered in # this is the vertical 'slot' that the span will be rendered in
# vertical_position = span_label_offset + (offset_step * (slot - 1)) # vertical_position = span_label_offset + (offset_step * (slot - 1))
s["render_slot"] = 0 s["render_slot"] = 0
for idx, token in enumerate(tokens): for idx, token in enumerate(tokens):
# Identify if a token belongs to a Span (and which) and if it's a # Identify if a token belongs to a Span (and which) and if it's a
# start token of said Span. We'll use this for the final HTML render # start token of said Span. We'll use this for the final HTML render
@ -197,11 +217,8 @@ class SpanRenderer:
span["render_slot"] = 0 span["render_slot"] = 0
token_markup["entities"] = entities token_markup["entities"] = entities
per_token_info.append(token_markup) per_token_info.append(token_markup)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction) return per_token_info
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup
def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str: def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
"""Render the markup from per-token information""" """Render the markup from per-token information"""

View File

@ -2,7 +2,7 @@ import numpy
import pytest import pytest
from spacy import displacy from spacy import displacy
from spacy.displacy.render import DependencyRenderer, EntityRenderer from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.fa import Persian from spacy.lang.fa import Persian
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
# Verify that the HTML tag is still escaped # Verify that the HTML tag is still escaped
html = displacy.render(doc, style="span") html = displacy.render(doc, style="span")
assert "<TEST>" in html assert "<TEST>" in html
@pytest.mark.issue(13056)
def test_displacy_span_stacking():
"""Test whether span stacking works properly for multiple overlapping spans."""
spans = [
{"start_token": 2, "end_token": 5, "label": "SkillNC"},
{"start_token": 0, "end_token": 2, "label": "Skill"},
{"start_token": 1, "end_token": 3, "label": "Skill"},
]
tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."]
per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens)
assert len(per_token_info) == len(tokens)
assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)])
assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)])
assert per_token_info[1]["entities"][1]["render_slot"] == 1
assert per_token_info[1]["entities"][2]["render_slot"] == 2
assert per_token_info[2]["entities"][1]["render_slot"] == 2
assert per_token_info[2]["entities"][2]["render_slot"] == 3