displacy: avoid overlapping arcs in manual mode (#10534)

* Added test for overlapping arcs

* Provide distinct levels to overlapping arcs

* Update return type hint for get_levels

* Improved formatting spacy/displacy/render.py

Co-authored-by: Ines Montani <ines@ines.io>

Co-authored-by: Joachim Fainberg <joachimfainberg@Joachims-MacBook-Pro.local>
Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
Joachim Fainberg 2022-04-05 09:08:02 +02:00 committed by GitHub
parent 0d0153db63
commit b91255a454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 8 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import uuid import uuid
import itertools import itertools
@ -270,7 +270,7 @@ class DependencyRenderer:
RETURNS (str): Rendered SVG markup. RETURNS (str): Rendered SVG markup.
""" """
self.levels = self.get_levels(arcs) self.levels = self.get_levels(arcs)
self.highest_level = len(self.levels) self.highest_level = max(self.levels.values(), default=0)
self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke
self.width = self.offset_x + len(words) * self.distance self.width = self.offset_x + len(words) * self.distance
self.height = self.offset_y + 3 * self.word_spacing self.height = self.offset_y + 3 * self.word_spacing
@ -330,7 +330,7 @@ class DependencyRenderer:
if start < 0 or end < 0: if start < 0 or end < 0:
error_args = dict(start=start, end=end, label=label, dir=direction) error_args = dict(start=start, end=end, label=label, dir=direction)
raise ValueError(Errors.E157.format(**error_args)) raise ValueError(Errors.E157.format(**error_args))
level = self.levels.index(end - start) + 1 level = self.levels[(start, end, label)]
x_start = self.offset_x + start * self.distance + self.arrow_spacing x_start = self.offset_x + start * self.distance + self.arrow_spacing
if self.direction == "rtl": if self.direction == "rtl":
x_start = self.width - x_start x_start = self.width - x_start
@ -346,7 +346,7 @@ class DependencyRenderer:
y_curve = self.offset_y - level * self.distance / 2 y_curve = self.offset_y - level * self.distance / 2
if self.compact: if self.compact:
y_curve = self.offset_y - level * self.distance / 6 y_curve = self.offset_y - level * self.distance / 6
if y_curve == 0 and len(self.levels) > 5: if y_curve == 0 and max(self.levels.values(), default=0) > 5:
y_curve = -self.distance y_curve = -self.distance
arrowhead = self.get_arrowhead(direction, x_start, y, x_end) arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
arc = self.get_arc(x_start, y, y_curve, x_end) arc = self.get_arc(x_start, y, y_curve, x_end)
@ -390,15 +390,22 @@ class DependencyRenderer:
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2) p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}" return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]: def get_levels(self, arcs: List[Dict[str, Any]]) -> Dict[Tuple[int, int, str], int]:
"""Calculate available arc height "levels". """Calculate available arc height "levels".
Used to calculate arrow heights dynamically and without wasting space. Used to calculate arrow heights dynamically and without wasting space.
args (list): Individual arcs and their start, end, direction and label. args (list): Individual arcs and their start, end, direction and label.
RETURNS (list): Arc levels sorted from lowest to highest. RETURNS (dict): Arc levels keyed by (start, end, label).
""" """
levels = set(map(lambda arc: arc["end"] - arc["start"], arcs)) length = max([arc["end"] for arc in arcs], default=0)
return sorted(list(levels)) max_level = [0] * length
levels = {}
for arc in sorted(arcs, key=lambda arc: arc["end"] - arc["start"]):
level = max(max_level[arc["start"] : arc["end"]]) + 1
for i in range(arc["start"], arc["end"]):
max_level[i] = level
levels[(arc["start"], arc["end"], arc["label"])] = level
return levels
class EntityRenderer: class EntityRenderer:

View File

@ -8,6 +8,26 @@ from spacy.lang.fa import Persian
from spacy.tokens import Span, Doc from spacy.tokens import Span, Doc
@pytest.mark.issue(5447)
def test_issue5447():
"""Test that overlapping arcs get separate levels."""
renderer = DependencyRenderer()
words = [
{"text": "This", "tag": "DT"},
{"text": "is", "tag": "VBZ"},
{"text": "a", "tag": "DT"},
{"text": "sentence.", "tag": "NN"},
]
arcs = [
{"start": 0, "end": 1, "label": "nsubj", "dir": "left"},
{"start": 2, "end": 3, "label": "det", "dir": "left"},
{"start": 2, "end": 3, "label": "overlap", "dir": "left"},
{"start": 1, "end": 3, "label": "attr", "dir": "left"},
]
html = renderer.render([{"words": words, "arcs": arcs}])
assert renderer.highest_level == 3
@pytest.mark.issue(2361) @pytest.mark.issue(2361)
def test_issue2361(de_vocab): def test_issue2361(de_vocab):
"""Test if < is escaped when rendering""" """Test if < is escaped when rendering"""