from typing import Any, Dict, List, Optional, Tuple, Union
import uuid
import itertools

from ..errors import Errors
from ..util import escape_html, minify_html, registry
from .templates import TPL_DEP_ARCS, TPL_DEP_SVG, TPL_DEP_WORDS
from .templates import TPL_DEP_WORDS_LEMMA, TPL_ENT, TPL_ENT_RTL, TPL_ENTS
from .templates import TPL_FIGURE, TPL_KB_LINK, TPL_PAGE, TPL_SPAN
from .templates import TPL_SPAN_RTL, TPL_SPAN_SLICE, TPL_SPAN_SLICE_RTL
from .templates import TPL_SPAN_START, TPL_SPAN_START_RTL, TPL_SPANS
from .templates import TPL_TITLE

DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"
DEFAULT_ENTITY_COLOR = "#ddd"
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FAC": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}


def _resolve_label_colors(options: Dict[str, Any]) -> Dict[str, Any]:
    """Build the label-to-color mapping used by the span and entity renderers.

    Later sources override earlier ones: DEFAULT_LABEL_COLORS, then every
    entry of registry.displacy_colors, then options["colors"].

    options (dict): Visualiser-specific options.
    RETURNS (dict): Colors keyed by upper-cased label.
    RAISES (ValueError): If a registered color entry is not (and does not
        return) a dict.
    """
    colors = dict(DEFAULT_LABEL_COLORS)
    for user_color in registry.displacy_colors.get_all().values():
        if callable(user_color):
            # Since this comes from the function registry, we want to make
            # sure we support functions that *return* a dict of colors
            user_color = user_color()
        if not isinstance(user_color, dict):
            raise ValueError(Errors.E925.format(obj=type(user_color)))
        colors.update(user_color)
    colors.update(options.get("colors", {}))
    return {label.upper(): color for label, color in colors.items()}


def _escape_with_line_breaks(text: str) -> str:
    """Escape text for HTML and turn each newline into a <br> tag.

    text (str): Raw text.
    RETURNS (str): Escaped markup with line breaks preserved.
    """
    return "<br>".join(escape_html(fragment) for fragment in text.split("\n"))


class SpanRenderer:
    """Render spans as HTML."""

    style = "span"

    def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
        """Initialise span renderer.

        options (dict): Visualiser-specific options (colors, top_offset,
            span_label_offset, top_offset_step, template). A custom template
            must be a dict with the keys "span", "slice" and "start".
        """
        options = {} if options is None else options
        # Set up the colors and overall look
        self.default_color = DEFAULT_ENTITY_COLOR
        self.colors = _resolve_label_colors(options)
        # Set up how the text and labels will be rendered
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG
        # These values are in px
        self.top_offset = options.get("top_offset", 40)
        # This is how far under the top offset the span labels appear
        self.span_label_offset = options.get("span_label_offset", 20)
        self.offset_step = options.get("top_offset_step", 17)
        # Set up which templates will be used
        template = options.get("template")
        # Remember whether the built-in templates are in use, so render() can
        # swap them once the text direction is known.
        self._use_default_templates = not template
        if template:
            self.span_template = template["span"]
            self.span_slice_template = template["slice"]
            self.span_start_template = template["start"]
        else:
            self._select_default_templates()

    def _select_default_templates(self) -> None:
        """Pick the built-in templates that match the current direction."""
        if self.direction == "rtl":
            self.span_template = TPL_SPAN_RTL
            self.span_slice_template = TPL_SPAN_SLICE_RTL
            self.span_start_template = TPL_SPAN_START_RTL
        else:
            self.span_template = TPL_SPAN
            self.span_slice_template = TPL_SPAN_SLICE
            self.span_start_template = TPL_SPAN_START

    def render(
        self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
    ) -> str:
        """Render complete markup.

        parsed (list): Span parses to render. Each entry needs the keys
            "tokens" and "spans" and may provide "title" and "settings".
        page (bool): Render parses wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (str): Rendered HTML markup.
        """
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                # Direction and language are taken from the first parse only
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
                # The direction is only known at this point, so the built-in
                # templates are re-selected here. Custom templates passed via
                # the options are left untouched.
                if self._use_default_templates:
                    self._select_default_templates()
            rendered.append(self.render_spans(p["tokens"], p["spans"], p.get("title")))
        if page:
            docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
            markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_spans(
        self,
        tokens: List[str],
        spans: List[Dict[str, Any]],
        title: Optional[str],
    ) -> str:
        """Render span types in text.

        Spans are rendered per-token, this means that for each token, we check
        if it's part of a span slice (a member of a span type) or a span start
        (the starting token of a given span type).

        tokens (list): Individual tokens in the text
        spans (list): Individual entity spans and their start_token,
            end_token, label, kb_id and kb_url.
        title (str / None): Document title set in Doc.user_data['title'].
        RETURNS (str): Rendered HTML markup.
        """
        per_token_info = []
        for idx, token in enumerate(tokens):
            # Identify if a token belongs to a Span (and which) and if it's a
            # start token of said Span. We'll use this for the final HTML render
            entities = []
            for span in spans:
                if not span["start_token"] <= idx < span["end_token"]:
                    continue
                kb_id = span.get("kb_id", "")
                kb_url = span.get("kb_url", "#")
                entities.append(
                    {
                        "label": span["label"],
                        "is_start": idx == span["start_token"],
                        "kb_link": (
                            TPL_KB_LINK.format(kb_id=kb_id, kb_url=kb_url)
                            if kb_id
                            else ""
                        ),
                    }
                )
            per_token_info.append({"text": token, "entities": entities})
        markup = self._render_markup(per_token_info)
        markup = TPL_SPANS.format(content=markup, dir=self.direction)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup

    def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
        """Render the markup from per-token information.

        per_token_info (list): One dict per token with the keys "text" and
            "entities", as built by render_spans.
        RETURNS (str): Concatenated markup of all tokens.
        """
        pieces = []
        for token in per_token_info:
            # Spans are stacked in the order in which they were passed in
            entities = token["entities"]
            if entities:
                slices = self._get_span_slices(entities)
                starts = self._get_span_starts(entities)
                total_height = (
                    self.top_offset
                    + self.span_label_offset
                    + (self.offset_step * (len(entities) - 1))
                )
                pieces.append(
                    self.span_template.format(
                        # Escape here as well, so that tokens inside a span are
                        # treated the same as tokens outside of one
                        text=escape_html(token["text"]),
                        span_slices=slices,
                        span_starts=starts,
                        total_height=total_height,
                    )
                )
            else:
                pieces.append(escape_html(token["text"] + " "))
        return "".join(pieces)

    def _get_span_slices(self, entities: List[Dict]) -> str:
        """Get the rendered markup of all Span slices.

        entities (list): Span information of a single token.
        RETURNS (str): Markup of one slice per entity, each one offset_step
            further down than the previous one.
        """
        span_slices = []
        for entity, step in zip(entities, itertools.count(step=self.offset_step)):
            color = self.colors.get(entity["label"].upper(), self.default_color)
            span_slices.append(
                self.span_slice_template.format(
                    bg=color, top_offset=self.top_offset + step
                )
            )
        return "".join(span_slices)

    def _get_span_starts(self, entities: List[Dict]) -> str:
        """Get the rendered markup of all Span start tokens.

        entities (list): Span information of a single token.
        RETURNS (str): Markup of the labels of all spans starting at this
            token. Spans that do not start here contribute nothing but still
            take up their vertical slot.
        """
        span_starts = []
        for entity, step in zip(entities, itertools.count(step=self.offset_step)):
            if not entity["is_start"]:
                continue
            color = self.colors.get(entity["label"].upper(), self.default_color)
            span_starts.append(
                self.span_start_template.format(
                    bg=color,
                    top_offset=self.top_offset + step,
                    label=entity["label"],
                    kb_link=entity["kb_link"],
                )
            )
        return "".join(span_starts)


class DependencyRenderer:
    """Render dependency parses as SVGs."""

    style = "dep"

    def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
        """Initialise dependency renderer.

        options (dict): Visualiser-specific options (compact, word_spacing,
            arrow_spacing, arrow_width, arrow_stroke, distance, offset_x,
            color, bg, font)
        """
        options = {} if options is None else options
        self.compact = options.get("compact", False)
        self.word_spacing = options.get("word_spacing", 45)
        self.arrow_spacing = options.get("arrow_spacing", 12 if self.compact else 20)
        self.arrow_width = options.get("arrow_width", 6 if self.compact else 10)
        self.arrow_stroke = options.get("arrow_stroke", 2)
        self.distance = options.get("distance", 150 if self.compact else 175)
        self.offset_x = options.get("offset_x", 50)
        self.color = options.get("color", "#000000")
        self.bg = options.get("bg", "#ffffff")
        self.font = options.get("font", "Arial")
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG

    def render(
        self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
    ) -> str:
        """Render complete markup.

        parsed (list): Dependency parses to render. Each entry needs the keys
            "words" and "arcs" and may provide "settings".
        page (bool): Render parses wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (str): Rendered SVG or HTML markup.
        """
        # Create a random ID prefix to make sure parses don't receive the
        # same ID, even if they're identical
        id_prefix = uuid.uuid4().hex
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                # Direction and language are taken from the first parse only
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
            render_id = f"{id_prefix}-{i}"
            svg = self.render_svg(render_id, p["words"], p["arcs"])
            rendered.append(svg)
        if page:
            content = "".join([TPL_FIGURE.format(content=svg) for svg in rendered])
            markup = TPL_PAGE.format(
                content=content, lang=self.lang, dir=self.direction
            )
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_svg(
        self,
        render_id: Union[int, str],
        words: List[Dict[str, Any]],
        arcs: List[Dict[str, Any]],
    ) -> str:
        """Render SVG.

        Stores the layout of the current parse (levels, highest_level,
        offset_y, width, height, id) on the renderer, where render_word and
        render_arrow read it from.

        render_id (Union[int, str]): Unique ID, typically index of document.
        words (list): Individual words and their tags.
        arcs (list): Individual arcs and their start, end, direction and label.
        RETURNS (str): Rendered SVG markup.
        """
        self.levels = self.get_levels(arcs)
        self.highest_level = max(self.levels.values(), default=0)
        self.offset_y = self.distance / 2 * self.highest_level + self.arrow_stroke
        self.width = self.offset_x + len(words) * self.distance
        self.height = self.offset_y + 3 * self.word_spacing
        self.id = render_id
        words_svg = [
            self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
            for i, w in enumerate(words)
        ]
        arcs_svg = [
            self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
            for i, a in enumerate(arcs)
        ]
        content = "".join(words_svg) + "".join(arcs_svg)
        return TPL_DEP_SVG.format(
            id=self.id,
            width=self.width,
            height=self.height,
            color=self.color,
            bg=self.bg,
            font=self.font,
            content=content,
            dir=self.direction,
            lang=self.lang,
        )

    def render_word(self, text: str, tag: str, lemma: Optional[str], i: int) -> str:
        """Render individual word.

        text (str): Word text.
        tag (str): Part-of-speech tag.
        lemma (str / None): Lemma of the word. If None, the template without
            a lemma line is used.
        i (int): Unique ID, typically word index.
        RETURNS (str): Rendered SVG markup.
        """
        y = self.offset_y + self.word_spacing
        x = self.offset_x + i * self.distance
        if self.direction == "rtl":
            x = self.width - x
        html_text = escape_html(text)
        if lemma is not None:
            # The lemma is free text just like the word itself, so it gets
            # the same escaping
            return TPL_DEP_WORDS_LEMMA.format(
                text=html_text, tag=tag, lemma=escape_html(lemma), x=x, y=y
            )
        return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)

    def render_arrow(
        self, label: str, start: int, end: int, direction: str, i: int
    ) -> str:
        """Render individual arrow.

        label (str): Dependency label.
        start (int): Index of start word.
        end (int): Index of end word.
        direction (str): Arrow direction, 'left' or 'right'.
        i (int): Unique ID, typically arrow index.
        RETURNS (str): Rendered SVG markup.
        RAISES (ValueError): If start or end is negative.
        """
        if start < 0 or end < 0:
            error_args = dict(start=start, end=end, label=label, dir=direction)
            raise ValueError(Errors.E157.format(**error_args))
        level = self.levels[(start, end, label)]
        x_start = self.offset_x + start * self.distance + self.arrow_spacing
        if self.direction == "rtl":
            x_start = self.width - x_start
        y = self.offset_y
        x_end = (
            self.offset_x
            + (end - start) * self.distance
            + start * self.distance
            - self.arrow_spacing * (self.highest_level - level) / 4
        )
        if self.direction == "rtl":
            x_end = self.width - x_end
        y_curve = self.offset_y - level * self.distance / 2
        if self.compact:
            y_curve = self.offset_y - level * self.distance / 6
        if y_curve == 0 and self.highest_level > 5:
            y_curve = -self.distance
        arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
        arc = self.get_arc(x_start, y, y_curve, x_end)
        label_side = "right" if self.direction == "rtl" else "left"
        return TPL_DEP_ARCS.format(
            id=self.id,
            i=i,
            stroke=self.arrow_stroke,
            head=arrowhead,
            label=label,
            label_side=label_side,
            arc=arc,
        )

    def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str:
        """Render individual arc.

        x_start (int): X-coordinate of arrow start point.
        y (int): Y-coordinate of arrow start and end point.
        y_curve (int): Y-coordinate of Cubic Bézier y_curve point.
        x_end (int): X-coordinate of arrow end point.
        RETURNS (str): Definition of the arc path ('d' attribute).
        """
        template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
        if self.compact:
            # Same points without the curve command, for compact mode
            template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
        return template.format(x=x_start, y=y, c=y_curve, e=x_end)

    def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str:
        """Render individual arrow head.

        direction (str): Arrow direction, 'left' or 'right'.
        x (int): X-coordinate of arrow start point.
        y (int): Y-coordinate of arrow start and end point.
        end (int): X-coordinate of arrow end point.
        RETURNS (str): Definition of the arrow head path ('d' attribute).
        """
        if direction == "left":
            p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
        else:
            p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
        return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"

    def get_levels(self, arcs: List[Dict[str, Any]]) -> Dict[Tuple[int, int, str], int]:
        """Calculate available arc height "levels".
        Used to calculate arrow heights dynamically and without wasting space.

        arcs (list): Individual arcs and their start, end, direction and label.
        RETURNS (dict): Arc levels keyed by (start, end, label).
        """
        # Drop duplicate arcs while keeping the input order. sorted() below is
        # stable, so arcs of equal length are then always processed in the
        # same order and receive the same levels on every run.
        unique_arcs = [
            dict(items)
            for items in dict.fromkeys(tuple(sorted(arc.items())) for arc in arcs)
        ]
        length = max((arc["end"] for arc in unique_arcs), default=0)
        max_level = [0] * length
        levels = {}
        # Shorter arcs are placed first, so they end up below longer ones
        for arc in sorted(unique_arcs, key=lambda arc: arc["end"] - arc["start"]):
            start, end = arc["start"], arc["end"]
            # default=0 covers arcs that span no position (start == end)
            level = max(max_level[start:end], default=0) + 1
            for i in range(start, end):
                max_level[i] = level
            levels[(start, end, arc["label"])] = level
        return levels


class EntityRenderer:
    """Render named entities as HTML."""

    style = "ent"

    def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
        """Initialise entity renderer.

        options (dict): Visualiser-specific options (colors, ents, template)
        """
        options = {} if options is None else options
        self.default_color = DEFAULT_ENTITY_COLOR
        self.colors = _resolve_label_colors(options)
        self.ents = options.get("ents", None)
        if self.ents is not None:
            self.ents = [ent.upper() for ent in self.ents]
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG
        template = options.get("template")
        # Remember whether the built-in template is in use, so render() can
        # swap it once the text direction is known.
        self._use_default_template = not template
        if template:
            self.ent_template = template
        else:
            self._select_default_template()

    def _select_default_template(self) -> None:
        """Pick the built-in template that matches the current direction."""
        if self.direction == "rtl":
            self.ent_template = TPL_ENT_RTL
        else:
            self.ent_template = TPL_ENT

    def render(
        self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
    ) -> str:
        """Render complete markup.

        parsed (list): Entity parses to render. Each entry needs the keys
            "text" and "ents" and may provide "title" and "settings".
        page (bool): Render parses wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (str): Rendered HTML markup.
        """
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                # Direction and language are taken from the first parse only
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
                # The direction is only known at this point, so the built-in
                # template is re-selected here. A custom template passed via
                # the options is left untouched.
                if self._use_default_template:
                    self._select_default_template()
            rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
        if page:
            docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
            markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_ents(
        self, text: str, spans: List[Dict[str, Any]], title: Optional[str]
    ) -> str:
        """Render entities in text.

        text (str): Original text.
        spans (list): Individual entity spans and their start, end, label,
            kb_id and kb_url. An optional "params" dict is passed on to the
            entity template as additional format arguments.
        title (str / None): Document title set in Doc.user_data['title'].
        RETURNS (str): Rendered HTML markup.
        """
        pieces = []
        offset = 0
        for span in spans:
            label = span["label"]
            start = span["start"]
            end = span["end"]
            kb_id = span.get("kb_id", "")
            kb_url = span.get("kb_url", "#")
            kb_link = TPL_KB_LINK.format(kb_id=kb_id, kb_url=kb_url) if kb_id else ""
            additional_params = span.get("params", {})
            entity = escape_html(text[start:end])
            # Text between the previous entity and this one
            pieces.append(_escape_with_line_breaks(text[offset:start]))
            if self.ents is None or label.upper() in self.ents:
                color = self.colors.get(label.upper(), self.default_color)
                ent_settings = {
                    "label": label,
                    "text": entity,
                    "bg": color,
                    "kb_link": kb_link,
                }
                ent_settings.update(additional_params)
                pieces.append(self.ent_template.format(**ent_settings))
            else:
                # Entities that are filtered out are shown as plain text
                pieces.append(entity)
            offset = end
        # Remaining text after the last entity
        pieces.append(_escape_with_line_breaks(text[offset:]))
        markup = TPL_ENTS.format(content="".join(pieces), dir=self.direction)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup