mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
a82c3153ad
<!--- Provide a general summary of your changes in the title. --> Referring #2452, fixing displacy arrow directions to match the input. ## Description The fix is simply replacing `direction is 'left'` with `direction == 'left'` to include the case `direction` is a `str` and not a `unicode`. ### Types of change bug fix ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [ ] I have submitted the spaCy Contributor Agreement. - [ ] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
230 lines
9.4 KiB
Python
230 lines
9.4 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS
|
|
from .templates import TPL_ENT, TPL_ENTS, TPL_FIGURE, TPL_TITLE, TPL_PAGE
|
|
from ..util import minify_html, escape_html
|
|
|
|
|
|
class DependencyRenderer(object):
    """Render dependency parses as SVGs."""
    style = 'dep'

    def __init__(self, options=None):
        """Initialise dependency renderer.

        options (dict): Visualiser-specific options (compact, word_spacing,
            arrow_spacing, arrow_width, arrow_stroke, distance, offset_x,
            color, bg, font). If None, all defaults are used.
        """
        # A None sentinel avoids sharing one mutable dict between all calls.
        options = {} if options is None else options
        self.compact = options.get('compact', False)
        self.word_spacing = options.get('word_spacing', 45)
        # Compact mode uses tighter arrow and word distances by default.
        self.arrow_spacing = options.get('arrow_spacing',
                                         12 if self.compact else 20)
        self.arrow_width = options.get('arrow_width',
                                       6 if self.compact else 10)
        self.arrow_stroke = options.get('arrow_stroke', 2)
        self.distance = options.get('distance', 150 if self.compact else 175)
        self.offset_x = options.get('offset_x', 50)
        self.color = options.get('color', '#000000')
        self.bg = options.get('bg', '#ffffff')
        self.font = options.get('font', 'Arial')

    def render(self, parsed, page=False, minify=False):
        """Render complete markup.

        parsed (list): Dependency parses to render. Each item is a dict
            with the keys 'words' and 'arcs'.
        page (bool): Render parses wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (unicode): Rendered SVG or HTML markup.
        """
        # The index of each parse doubles as the render ID of its SVG.
        rendered = [self.render_svg(i, p['words'], p['arcs'])
                    for i, p in enumerate(parsed)]
        if page:
            content = ''.join([TPL_FIGURE.format(content=svg)
                               for svg in rendered])
            markup = TPL_PAGE.format(content=content)
        else:
            markup = ''.join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_svg(self, render_id, words, arcs):
        """Render SVG.

        Stores the layout values (levels, highest_level, offset_y, width,
        height, id) on the instance; render_word and render_arrow read them,
        so they only work after this method has set them.

        render_id (int): Unique ID, typically index of document.
        words (list): Individual words and their tags. Each item is a dict
            with the keys 'text' and 'tag'.
        arcs (list): Individual arcs and their start, end, direction and
            label. Each item is a dict with the keys 'start', 'end', 'dir'
            and 'label'.
        RETURNS (unicode): Rendered SVG markup.
        """
        self.levels = self.get_levels(arcs)
        self.highest_level = len(self.levels)
        # Vertical space reserved above the words grows with the number of
        # distinct arc levels.
        self.offset_y = self.distance/2*self.highest_level+self.arrow_stroke
        self.width = self.offset_x+len(words)*self.distance
        self.height = self.offset_y+3*self.word_spacing
        self.id = render_id
        words = [self.render_word(w['text'], w['tag'], i)
                 for i, w in enumerate(words)]
        arcs = [self.render_arrow(a['label'], a['start'],
                                  a['end'], a['dir'], i)
                for i, a in enumerate(arcs)]
        content = ''.join(words) + ''.join(arcs)
        return TPL_DEP_SVG.format(id=self.id, width=self.width,
                                  height=self.height, color=self.color,
                                  bg=self.bg, font=self.font, content=content)

    def render_word(self, text, tag, i):
        """Render individual word.

        text (unicode): Word text. HTML-escaped before being inserted.
        tag (unicode): Part-of-speech tag. Inserted as given.
        i (int): Unique ID, typically word index.
        RETURNS (unicode): Rendered SVG markup.
        """
        y = self.offset_y+self.word_spacing
        x = self.offset_x+i*self.distance
        html_text = escape_html(text)
        return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)

    def render_arrow(self, label, start, end, direction, i):
        """Render individual arrow.

        label (unicode): Dependency label.
        start (int): Index of start word.
        end (int): Index of end word.
        direction (unicode): Arrow direction, 'left' or 'right'.
        i (int): Unique ID, typically arrow index.
        RETURNS (unicode): Rendered SVG markup.
        """
        # 1-based rank of this arc's span among all distinct spans.
        level = self.levels.index(end-start)+1
        x_start = self.offset_x+start*self.distance+self.arrow_spacing
        y = self.offset_y
        # The end point is pulled in more for lower levels than higher ones.
        x_end = (self.offset_x+(end-start)*self.distance+start*self.distance
                 - self.arrow_spacing*(self.highest_level-level)/4)
        y_curve = self.offset_y-level*self.distance/2
        if self.compact:
            y_curve = self.offset_y-level*self.distance/6
        # Special case: a curve point of exactly 0 with more than five
        # levels is moved up by one word distance.
        if y_curve == 0 and len(self.levels) > 5:
            y_curve = -self.distance
        arrowhead = self.get_arrowhead(direction, x_start, y, x_end)
        arc = self.get_arc(x_start, y, y_curve, x_end)
        return TPL_DEP_ARCS.format(id=self.id, i=i, stroke=self.arrow_stroke,
                                   head=arrowhead, label=label, arc=arc)

    def get_arc(self, x_start, y, y_curve, x_end):
        """Render individual arc.

        x_start (int): X-coordinate of arrow start point.
        y (int): Y-coordinate of arrow start and end point.
        y_curve (int): Y-coordinate of Cubic Bézier y_curve point.
        x_end (int): X-coordinate of arrow end point.
        RETURNS (unicode): Definition of the arc path ('d' attribute).
        """
        # Compact mode omits the 'C' (curve) command from the path.
        if self.compact:
            template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
        else:
            template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
        return template.format(x=x_start, y=y, c=y_curve, e=x_end)

    def get_arrowhead(self, direction, x, y, end):
        """Render individual arrow head.

        direction (unicode): Arrow direction, 'left' or 'right'. Any value
            other than 'left' is drawn at the end point.
        x (int): X-coordinate of arrow start point.
        y (int): Y-coordinate of arrow start and end point.
        end (int): X-coordinate of arrow end point.
        RETURNS (unicode): Definition of the arrow head path ('d' attribute).
        """
        # Compare by value: the direction may be a str or a unicode object.
        if direction == 'left':
            pos1, pos2, pos3 = (x, x-self.arrow_width+2, x+self.arrow_width-2)
        else:
            pos1, pos2, pos3 = (end, end+self.arrow_width-2,
                                end-self.arrow_width+2)
        arrowhead = (pos1, y+2, pos2, y-self.arrow_width, pos3,
                     y-self.arrow_width)
        return "M{},{} L{},{} {},{}".format(*arrowhead)

    def get_levels(self, arcs):
        """Calculate available arc height "levels".
        Used to calculate arrow heights dynamically and without wasting space.

        arcs (list): Individual arcs and their start, end, direction and
            label. Only the 'start' and 'end' keys are read.
        RETURNS (list): Distinct arc spans (end - start), sorted from lowest
            to highest.
        """
        return sorted({arc['end'] - arc['start'] for arc in arcs})
|
|
|
|
|
|
class EntityRenderer(object):
    """Render named entities as HTML."""
    style = 'ent'

    def __init__(self, options=None):
        """Initialise entity renderer.

        options (dict): Visualiser-specific options (colors, ents). If None,
            all defaults are used.
        """
        # A None sentinel avoids sharing one mutable dict between all calls.
        options = {} if options is None else options
        colors = {'ORG': '#7aecec', 'PRODUCT': '#bfeeb7', 'GPE': '#feca74',
                  'LOC': '#ff9561', 'PERSON': '#aa9cfc', 'NORP': '#c887fb',
                  'FACILITY': '#9cc9cc', 'EVENT': '#ffeb80', 'LAW': '#ff8197',
                  'LANGUAGE': '#ff8197', 'WORK_OF_ART': '#f0d0ff',
                  'DATE': '#bfe1d9', 'TIME': '#bfe1d9', 'MONEY': '#e4e7d2',
                  'QUANTITY': '#e4e7d2', 'ORDINAL': '#e4e7d2',
                  'CARDINAL': '#e4e7d2', 'PERCENT': '#e4e7d2'}
        # User-supplied colors override or extend the built-in palette.
        colors.update(options.get('colors', {}))
        self.default_color = '#ddd'
        self.colors = colors
        # None means "highlight every label"; otherwise only listed labels.
        self.ents = options.get('ents', None)

    def render(self, parsed, page=False, minify=False):
        """Render complete markup.

        parsed (list): Documents to render. Each item is a dict with the
            keys 'text' and 'ents' and, optionally, 'title'.
        page (bool): Render documents wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (unicode): Rendered HTML markup.
        """
        rendered = [self.render_ents(p['text'], p['ents'],
                                     p.get('title', None)) for p in parsed]
        if page:
            docs = ''.join([TPL_FIGURE.format(content=doc)
                            for doc in rendered])
            markup = TPL_PAGE.format(content=docs)
        else:
            markup = ''.join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_ents(self, text, spans, title):
        """Render entities in text.

        Both the plain text and the entity text are HTML-escaped, matching
        what DependencyRenderer.render_word does for word text.

        text (unicode): Original text.
        spans (list): Individual entity spans and their start, end and label.
            The loop advances a single offset, so spans are expected in
            document order.
        title (unicode or None): Document title set in Doc.user_data['title'].
        RETURNS (unicode): Rendered HTML markup.
        """
        parts = []
        offset = 0
        for span in spans:
            label = span['label']
            start = span['start']
            end = span['end']
            # Plain text between the previous entity and this one.
            parts.append(self._render_plain(text[offset:start]))
            entity = escape_html(text[start:end])
            if self.ents is None or label.upper() in self.ents:
                color = self.colors.get(label.upper(), self.default_color)
                parts.append(TPL_ENT.format(label=label, text=entity,
                                            bg=color))
            else:
                # Label filtered out by the 'ents' option: emit unhighlighted.
                parts.append(entity)
            offset = end
        # Text after the last entity gets the same escaping and line-break
        # handling as the text before entities.
        parts.append(self._render_plain(text[offset:]))
        markup = TPL_ENTS.format(content=''.join(parts), colors=self.colors)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup

    def _render_plain(self, text):
        """Escape non-entity text and turn newlines into line-break tags.

        text (unicode): Text outside of any entity span.
        RETURNS (unicode): HTML-safe markup for the text.
        """
        # The '</br>' string is kept exactly as this renderer has always
        # emitted it.
        return '</br>'.join(escape_html(fragment)
                            for fragment in text.split('\n'))
|