mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-13 16:05:50 +03:00
displacy: avoid overlapping arcs in manual mode (#10534)
* Added test for overlapping arcs
* Provide distinct levels to overlapping arcs
* Update return type hint for get_levels
* Improved formatting spacy/displacy/render.py

Co-authored-by: Ines Montani <ines@ines.io>
Co-authored-by: Joachim Fainberg <joachimfainberg@Joachims-MacBook-Pro.local>
Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
0d0153db63
commit
b91255a454
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
import uuid
|
import uuid
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
@ -270,7 +270,7 @@ class DependencyRenderer:
|
||||||
RETURNS (str): Rendered SVG markup.
|
RETURNS (str): Rendered SVG markup.
|
||||||
"""
|
"""
|
||||||
self.levels = self.get_levels(arcs)
|
self.levels = self.get_levels(arcs)
|
||||||
self.highest_level = len(self.levels)
|
self.highest_level = max(self.levels.values(), default=0)
|
||||||
self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke
|
self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke
|
||||||
self.width = self.offset_x + len(words) * self.distance
|
self.width = self.offset_x + len(words) * self.distance
|
||||||
self.height = self.offset_y + 3 * self.word_spacing
|
self.height = self.offset_y + 3 * self.word_spacing
|
||||||
|
@ -330,7 +330,7 @@ class DependencyRenderer:
|
||||||
if start < 0 or end < 0:
|
if start < 0 or end < 0:
|
||||||
error_args = dict(start=start, end=end, label=label, dir=direction)
|
error_args = dict(start=start, end=end, label=label, dir=direction)
|
||||||
raise ValueError(Errors.E157.format(**error_args))
|
raise ValueError(Errors.E157.format(**error_args))
|
||||||
level = self.levels.index(end - start) + 1
|
level = self.levels[(start, end, label)]
|
||||||
x_start = self.offset_x + start * self.distance + self.arrow_spacing
|
x_start = self.offset_x + start * self.distance + self.arrow_spacing
|
||||||
if self.direction == "rtl":
|
if self.direction == "rtl":
|
||||||
x_start = self.width - x_start
|
x_start = self.width - x_start
|
||||||
|
@ -346,7 +346,7 @@ class DependencyRenderer:
|
||||||
y_curve = self.offset_y - level * self.distance / 2
|
y_curve = self.offset_y - level * self.distance / 2
|
||||||
if self.compact:
|
if self.compact:
|
||||||
y_curve = self.offset_y - level * self.distance / 6
|
y_curve = self.offset_y - level * self.distance / 6
|
||||||
if y_curve == 0 and len(self.levels) > 5:
|
if y_curve == 0 and max(self.levels.values(), default=0) > 5:
|
||||||
y_curve = -self.distance
|
y_curve = -self.distance
|
||||||
arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
|
arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
|
||||||
arc = self.get_arc(x_start, y, y_curve, x_end)
|
arc = self.get_arc(x_start, y, y_curve, x_end)
|
||||||
|
@ -390,15 +390,22 @@ class DependencyRenderer:
|
||||||
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
|
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
|
||||||
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
|
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
|
||||||
|
|
||||||
def get_levels(self, arcs: List[Dict[str, Any]]) -> Dict[Tuple[int, int, str], int]:
    """Calculate available arc height "levels".
    Used to calculate arrow heights dynamically and without wasting space.

    Arcs are processed from shortest to longest. Each arc is placed one
    level above the highest arc already covering any position it spans, so
    arcs that cover the same position never share a level. Arcs repeating an
    already-seen (start, end, label) reuse the existing level instead of
    claiming a new one.

    arcs (list): Individual arcs and their start, end, direction and label.
    RETURNS (dict): Arc levels keyed by (start, end, label).
    """
    length = max((arc["end"] for arc in arcs), default=0)
    # max_level[i] is the highest level assigned so far to an arc covering
    # position i (positions run from an arc's start up to, excluding, its end).
    max_level = [0] * length
    levels: Dict[Tuple[int, int, str], int] = {}
    for arc in sorted(arcs, key=lambda a: a["end"] - a["start"]):
        start, end = arc["start"], arc["end"]
        key = (start, end, arc["label"])
        if key in levels:
            # A repeated arc would otherwise stack a second level while the
            # dict kept only the last one, leaving a gap and inflating the
            # highest level.
            continue
        # default=0 keeps an arc with an empty span (start == end) from
        # raising ValueError on max() of an empty slice.
        level = max(max_level[start:end], default=0) + 1
        for i in range(start, end):
            max_level[i] = level
        levels[key] = level
    return levels
|
||||||
|
|
||||||
|
|
||||||
class EntityRenderer:
|
class EntityRenderer:
|
||||||
|
|
|
@ -8,6 +8,26 @@ from spacy.lang.fa import Persian
|
||||||
from spacy.tokens import Span, Doc
|
from spacy.tokens import Span, Doc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(5447)
def test_issue5447():
    """Check that arcs spanning the same tokens are drawn on separate levels."""
    tagged_tokens = (("This", "DT"), ("is", "VBZ"), ("a", "DT"), ("sentence.", "NN"))
    labelled_spans = ((0, 1, "nsubj"), (2, 3, "det"), (2, 3, "overlap"), (1, 3, "attr"))
    parsed = {
        "words": [{"text": text, "tag": tag} for text, tag in tagged_tokens],
        "arcs": [
            {"start": start, "end": end, "label": label, "dir": "left"}
            for start, end, label in labelled_spans
        ],
    }
    renderer = DependencyRenderer()
    renderer.render([parsed])
    # The two 2 -> 3 arcs must not share a level, and the 1 -> 3 arc that
    # encloses them has to sit above both, giving three levels in total.
    assert renderer.highest_level == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(2361)
|
@pytest.mark.issue(2361)
|
||||||
def test_issue2361(de_vocab):
|
def test_issue2361(de_vocab):
|
||||||
"""Test if < is escaped when rendering"""
|
"""Test if < is escaped when rendering"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user