From b91255a454fd91ca9e680143ade21f5570327b44 Mon Sep 17 00:00:00 2001 From: Joachim Fainberg Date: Tue, 5 Apr 2022 09:08:02 +0200 Subject: [PATCH] displacy: avoid overlapping arcs in manual mode (#10534) * Added test for overlapping arcs * Provide distinct levels to overlapping arcs * Update return type hint for get_levels * Improved formatting spacy/displacy/render.py Co-authored-by: Ines Montani Co-authored-by: Joachim Fainberg Co-authored-by: Ines Montani --- spacy/displacy/render.py | 23 +++++++++++++++-------- spacy/tests/test_displacy.py | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index 2925c68a0..8d39e62f5 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import uuid import itertools @@ -270,7 +270,7 @@ class DependencyRenderer: RETURNS (str): Rendered SVG markup. """ self.levels = self.get_levels(arcs) - self.highest_level = len(self.levels) + self.highest_level = max(self.levels.values(), default=0) self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke self.width = self.offset_x + len(words) * self.distance self.height = self.offset_y + 3 * self.word_spacing @@ -330,7 +330,7 @@ class DependencyRenderer: if start < 0 or end < 0: error_args = dict(start=start, end=end, label=label, dir=direction) raise ValueError(Errors.E157.format(**error_args)) - level = self.levels.index(end - start) + 1 + level = self.levels[(start, end, label)] x_start = self.offset_x + start * self.distance + self.arrow_spacing if self.direction == "rtl": x_start = self.width - x_start @@ -346,7 +346,7 @@ class DependencyRenderer: y_curve = self.offset_y - level * self.distance / 2 if self.compact: y_curve = self.offset_y - level * self.distance / 6 - if y_curve == 0 and len(self.levels) > 5: + if y_curve == 0 and max(self.levels.values(), default=0) > 5: y_curve = -self.distance arrowhead = self.get_arrowhead(direction, x_start, y, x_end) arc = self.get_arc(x_start, y, y_curve, x_end) @@ -390,15 +390,22 @@ class DependencyRenderer: p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2) return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}" - def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]: + def get_levels(self, arcs: List[Dict[str, Any]]) -> Dict[Tuple[int, int, str], int]: """Calculate available arc height "levels". Used to calculate arrow heights dynamically and without wasting space. args (list): Individual arcs and their start, end, direction and label. - RETURNS (list): Arc levels sorted from lowest to highest. + RETURNS (dict): Arc levels keyed by (start, end, label). """ - levels = set(map(lambda arc: arc["end"] - arc["start"], arcs)) - return sorted(list(levels)) + length = max([arc["end"] for arc in arcs], default=0) + max_level = [0] * length + levels = {} + for arc in sorted(arcs, key=lambda arc: arc["end"] - arc["start"]): + level = max(max_level[arc["start"] : arc["end"]]) + 1 + for i in range(arc["start"], arc["end"]): + max_level[i] = level + levels[(arc["start"], arc["end"], arc["label"])] = level + return levels class EntityRenderer: diff --git a/spacy/tests/test_displacy.py b/spacy/tests/test_displacy.py index ccad7e342..95dc47a19 100644 --- a/spacy/tests/test_displacy.py +++ b/spacy/tests/test_displacy.py @@ -8,6 +8,26 @@ from spacy.lang.fa import Persian from spacy.tokens import Span, Doc +@pytest.mark.issue(5447) +def test_issue5447(): + """Test that overlapping arcs get separate levels.""" + renderer = DependencyRenderer() + words = [ + {"text": "This", "tag": "DT"}, + {"text": "is", "tag": "VBZ"}, + {"text": "a", "tag": "DT"}, + {"text": "sentence.", "tag": "NN"}, + ] + arcs = [ + {"start": 0, "end": 1, "label": "nsubj", "dir": "left"}, + {"start": 2, "end": 3, "label": "det", "dir": "left"}, + {"start": 2, "end": 3, "label": "overlap", "dir": "left"}, + {"start": 1, "end": 3, "label": "attr", "dir": "left"}, + ] + html = renderer.render([{"words": words, "arcs": arcs}]) + assert renderer.highest_level == 3 + + @pytest.mark.issue(2361) def test_issue2361(de_vocab): """Test if < is escaped when rendering"""