mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Fix displacy span stacking (#13068)
* Fix displacy span stacking. * Format. Remove counter. * Remove test files. * Add unit test. Refactor to allow for unit test. * Fix off-by-one error in tests.
This commit is contained in:
parent
a804b83a4b
commit
c4e2daf6ef
|
@ -142,7 +142,25 @@ class SpanRenderer:
|
|||
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
|
||||
title (str / None): Document title set in Doc.user_data['title'].
|
||||
"""
|
||||
per_token_info = []
|
||||
per_token_info = self._assemble_per_token_info(tokens, spans)
|
||||
markup = self._render_markup(per_token_info)
|
||||
markup = TPL_SPANS.format(content=markup, dir=self.direction)
|
||||
if title:
|
||||
markup = TPL_TITLE.format(title=title) + markup
|
||||
return markup
|
||||
|
||||
@staticmethod
|
||||
def _assemble_per_token_info(
|
||||
tokens: List[str], spans: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, List[Dict[str, Any]]]]:
|
||||
"""Assembles token info used to generate markup in render_spans().
|
||||
tokens (List[str]): Tokens in text.
|
||||
spans (List[Dict[str, Any]]): Spans in text.
|
||||
RETURNS (List[Dict[str, List[Dict, str, Any]]]): Per token info needed to render HTML markup for given tokens
|
||||
and spans.
|
||||
"""
|
||||
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []
|
||||
|
||||
# we must sort so that we can correctly describe when spans need to "stack"
|
||||
# which is determined by their start token, then span length (longer spans on top),
|
||||
# then break any remaining ties with the span label
|
||||
|
@ -154,21 +172,22 @@ class SpanRenderer:
|
|||
s["label"],
|
||||
),
|
||||
)
|
||||
|
||||
for s in spans:
|
||||
# this is the vertical 'slot' that the span will be rendered in
|
||||
# vertical_position = span_label_offset + (offset_step * (slot - 1))
|
||||
s["render_slot"] = 0
|
||||
|
||||
for idx, token in enumerate(tokens):
|
||||
# Identify if a token belongs to a Span (and which) and if it's a
|
||||
# start token of said Span. We'll use this for the final HTML render
|
||||
token_markup: Dict[str, Any] = {}
|
||||
token_markup["text"] = token
|
||||
concurrent_spans = 0
|
||||
intersecting_spans: List[Dict[str, Any]] = []
|
||||
entities = []
|
||||
for span in spans:
|
||||
ent = {}
|
||||
if span["start_token"] <= idx < span["end_token"]:
|
||||
concurrent_spans += 1
|
||||
span_start = idx == span["start_token"]
|
||||
ent["label"] = span["label"]
|
||||
ent["is_start"] = span_start
|
||||
|
@ -176,7 +195,12 @@ class SpanRenderer:
|
|||
# When the span starts, we need to know how many other
|
||||
# spans are on the 'span stack' and will be rendered.
|
||||
# This value becomes the vertical render slot for this entire span
|
||||
span["render_slot"] = concurrent_spans
|
||||
span["render_slot"] = (
|
||||
intersecting_spans[-1]["render_slot"]
|
||||
if len(intersecting_spans)
|
||||
else 0
|
||||
) + 1
|
||||
intersecting_spans.append(span)
|
||||
ent["render_slot"] = span["render_slot"]
|
||||
kb_id = span.get("kb_id", "")
|
||||
kb_url = span.get("kb_url", "#")
|
||||
|
@ -193,11 +217,8 @@ class SpanRenderer:
|
|||
span["render_slot"] = 0
|
||||
token_markup["entities"] = entities
|
||||
per_token_info.append(token_markup)
|
||||
markup = self._render_markup(per_token_info)
|
||||
markup = TPL_SPANS.format(content=markup, dir=self.direction)
|
||||
if title:
|
||||
markup = TPL_TITLE.format(title=title) + markup
|
||||
return markup
|
||||
|
||||
return per_token_info
|
||||
|
||||
def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
|
||||
"""Render the markup from per-token information"""
|
||||
|
|
|
@ -2,7 +2,7 @@ import numpy
|
|||
import pytest
|
||||
|
||||
from spacy import displacy
|
||||
from spacy.displacy.render import DependencyRenderer, EntityRenderer
|
||||
from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.fa import Persian
|
||||
from spacy.tokens import Doc, Span
|
||||
|
@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
|
|||
# Verify that the HTML tag is still escaped
|
||||
html = displacy.render(doc, style="span")
|
||||
assert "<TEST>" in html
|
||||
|
||||
|
||||
@pytest.mark.issue(13056)
|
||||
def test_displacy_span_stacking():
|
||||
"""Test whether span stacking works properly for multiple overlapping spans."""
|
||||
spans = [
|
||||
{"start_token": 2, "end_token": 5, "label": "SkillNC"},
|
||||
{"start_token": 0, "end_token": 2, "label": "Skill"},
|
||||
{"start_token": 1, "end_token": 3, "label": "Skill"},
|
||||
]
|
||||
tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."]
|
||||
per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens)
|
||||
|
||||
assert len(per_token_info) == len(tokens)
|
||||
assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)])
|
||||
assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)])
|
||||
assert per_token_info[1]["entities"][0]["render_slot"] == 1
|
||||
assert per_token_info[1]["entities"][1]["render_slot"] == 2
|
||||
assert per_token_info[2]["entities"][0]["render_slot"] == 2
|
||||
assert per_token_info[2]["entities"][1]["render_slot"] == 3
|
||||
|
|
Loading…
Reference in New Issue
Block a user