displacy: avoid overlapping arcs in manual mode (#10534)

* Added test for overlapping arcs

* Provide distinct levels to overlapping arcs

* Update return type hint for get_levels

* Improved formatting spacy/displacy/render.py

Co-authored-by: Ines Montani <ines@ines.io>

Co-authored-by: Joachim Fainberg <joachimfainberg@Joachims-MacBook-Pro.local>
Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
Joachim Fainberg 2022-04-05 09:08:02 +02:00 committed by GitHub
parent 0d0153db63
commit b91255a454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 8 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import uuid
import itertools
@ -270,7 +270,7 @@ class DependencyRenderer:
RETURNS (str): Rendered SVG markup.
"""
self.levels = self.get_levels(arcs)
self.highest_level = len(self.levels)
self.highest_level = max(self.levels.values(), default=0)
self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke
self.width = self.offset_x + len(words) * self.distance
self.height = self.offset_y + 3 * self.word_spacing
@ -330,7 +330,7 @@ class DependencyRenderer:
if start < 0 or end < 0:
error_args = dict(start=start, end=end, label=label, dir=direction)
raise ValueError(Errors.E157.format(**error_args))
level = self.levels.index(end - start) + 1
level = self.levels[(start, end, label)]
x_start = self.offset_x + start * self.distance + self.arrow_spacing
if self.direction == "rtl":
x_start = self.width - x_start
@ -346,7 +346,7 @@ class DependencyRenderer:
y_curve = self.offset_y - level * self.distance / 2
if self.compact:
y_curve = self.offset_y - level * self.distance / 6
if y_curve == 0 and len(self.levels) > 5:
if y_curve == 0 and max(self.levels.values(), default=0) > 5:
y_curve = -self.distance
arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
arc = self.get_arc(x_start, y, y_curve, x_end)
@ -390,15 +390,22 @@ class DependencyRenderer:
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]:
def get_levels(self, arcs: List[Dict[str, Any]]) -> Dict[Tuple[int, int, str], int]:
"""Calculate available arc height "levels".
Used to calculate arrow heights dynamically and without wasting space.
args (list): Individual arcs and their start, end, direction and label.
RETURNS (list): Arc levels sorted from lowest to highest.
RETURNS (dict): Arc levels keyed by (start, end, label).
"""
levels = set(map(lambda arc: arc["end"] - arc["start"], arcs))
return sorted(list(levels))
length = max([arc["end"] for arc in arcs], default=0)
max_level = [0] * length
levels = {}
for arc in sorted(arcs, key=lambda arc: arc["end"] - arc["start"]):
level = max(max_level[arc["start"] : arc["end"]]) + 1
for i in range(arc["start"], arc["end"]):
max_level[i] = level
levels[(arc["start"], arc["end"], arc["label"])] = level
return levels
class EntityRenderer:

View File

@ -8,6 +8,26 @@ from spacy.lang.fa import Persian
from spacy.tokens import Span, Doc
@pytest.mark.issue(5447)
def test_issue5447():
"""Test that overlapping arcs get separate levels."""
renderer = DependencyRenderer()
words = [
{"text": "This", "tag": "DT"},
{"text": "is", "tag": "VBZ"},
{"text": "a", "tag": "DT"},
{"text": "sentence.", "tag": "NN"},
]
arcs = [
{"start": 0, "end": 1, "label": "nsubj", "dir": "left"},
{"start": 2, "end": 3, "label": "det", "dir": "left"},
{"start": 2, "end": 3, "label": "overlap", "dir": "left"},
{"start": 1, "end": 3, "label": "attr", "dir": "left"},
]
html = renderer.render([{"words": words, "arcs": arcs}])
assert renderer.highest_level == 3
@pytest.mark.issue(2361)
def test_issue2361(de_vocab):
"""Test if < is escaped when rendering"""