Tidy up displaCy

This commit is contained in:
ines 2017-10-27 14:39:19 +02:00
parent ea4a41c8fb
commit e3265998c0
2 changed files with 41 additions and 28 deletions

View File

@ -12,7 +12,7 @@ IS_JUPYTER = is_in_jupyter()
def render(docs, style='dep', page=False, minify=False, jupyter=IS_JUPYTER, def render(docs, style='dep', page=False, minify=False, jupyter=IS_JUPYTER,
options={}, manual=False): options={}, manual=False):
"""Render displaCy visualisation. """Render displaCy visualisation.
docs (list or Doc): Document(s) to visualise. docs (list or Doc): Document(s) to visualise.
@ -21,7 +21,7 @@ def render(docs, style='dep', page=False, minify=False, jupyter=IS_JUPYTER,
minify (bool): Minify HTML markup. minify (bool): Minify HTML markup.
jupyter (bool): Experimental, use Jupyter's `display()` to output markup. jupyter (bool): Experimental, use Jupyter's `display()` to output markup.
options (dict): Visualiser-specific options, e.g. colors. options (dict): Visualiser-specific options, e.g. colors.
manual (bool): Don't parse `Doc` and instead, expect a dict or list of dicts. manual (bool): Don't parse `Doc` and instead expect a dict/list of dicts.
RETURNS (unicode): Rendered HTML markup. RETURNS (unicode): Rendered HTML markup.
""" """
factories = {'dep': (DependencyRenderer, parse_deps), factories = {'dep': (DependencyRenderer, parse_deps),
@ -35,7 +35,7 @@ def render(docs, style='dep', page=False, minify=False, jupyter=IS_JUPYTER,
parsed = [converter(doc, options) for doc in docs] if not manual else docs parsed = [converter(doc, options) for doc in docs] if not manual else docs
_html['parsed'] = renderer.render(parsed, page=page, minify=minify).strip() _html['parsed'] = renderer.render(parsed, page=page, minify=minify).strip()
html = _html['parsed'] html = _html['parsed']
if jupyter: # return HTML rendered by IPython display() if jupyter: # return HTML rendered by IPython display()
from IPython.core.display import display, HTML from IPython.core.display import display, HTML
return display(HTML(html)) return display(HTML(html))
return html return html
@ -50,13 +50,15 @@ def serve(docs, style='dep', page=True, minify=False, options={}, manual=False,
page (bool): Render markup as full HTML page. page (bool): Render markup as full HTML page.
minify (bool): Minify HTML markup. minify (bool): Minify HTML markup.
options (dict): Visualiser-specific options, e.g. colors. options (dict): Visualiser-specific options, e.g. colors.
manual (bool): Don't parse `Doc` and instead, expect a dict or list of dicts. manual (bool): Don't parse `Doc` and instead expect a dict/list of dicts.
port (int): Port to serve visualisation. port (int): Port to serve visualisation.
""" """
from wsgiref import simple_server from wsgiref import simple_server
render(docs, style=style, page=page, minify=minify, options=options, manual=manual) render(docs, style=style, page=page, minify=minify, options=options,
manual=manual)
httpd = simple_server.make_server('0.0.0.0', port, app) httpd = simple_server.make_server('0.0.0.0', port, app)
prints("Using the '%s' visualizer" % style, title="Serving on port %d..." % port) prints("Using the '%s' visualizer" % style,
title="Serving on port %d..." % port)
try: try:
httpd.serve_forever() httpd.serve_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
@ -67,7 +69,8 @@ def serve(docs, style='dep', page=True, minify=False, options={}, manual=False,
def app(environ, start_response): def app(environ, start_response):
# headers and status need to be bytes in Python 2, see #1227 # headers and status need to be bytes in Python 2, see #1227
headers = [(b_to_str(b'Content-type'), b_to_str(b'text/html; charset=utf-8'))] headers = [(b_to_str(b'Content-type'),
b_to_str(b'text/html; charset=utf-8'))]
start_response(b_to_str(b'200 OK'), headers) start_response(b_to_str(b'200 OK'), headers)
res = _html['parsed'].encode(encoding='utf-8') res = _html['parsed'].encode(encoding='utf-8')
return [res] return [res]
@ -89,9 +92,9 @@ def parse_deps(orig_doc, options={}):
end = word.i + 1 end = word.i + 1
while end < len(doc) and doc[end].is_punct: while end < len(doc) and doc[end].is_punct:
end += 1 end += 1
span = doc[start : end] span = doc[start:end]
spans.append((span.start_char, span.end_char, word.tag_, spans.append((span.start_char, span.end_char, word.tag_,
word.lemma_, word.ent_type_)) word.lemma_, word.ent_type_))
for span_props in spans: for span_props in spans:
doc.merge(*span_props) doc.merge(*span_props)
words = [{'text': w.text, 'tag': w.tag_} for w in doc] words = [{'text': w.text, 'tag': w.tag_} for w in doc]
@ -113,6 +116,7 @@ def parse_ents(doc, options={}):
RETURNS (dict): Generated entities keyed by text (original text) and ents. RETURNS (dict): Generated entities keyed by text (original text) and ents.
""" """
ents = [{'start': ent.start_char, 'end': ent.end_char, 'label': ent.label_} ents = [{'start': ent.start_char, 'end': ent.end_char, 'label': ent.label_}
for ent in doc.ents] for ent in doc.ents]
title = doc.user_data.get('title', None) if hasattr(doc, 'user_data') else None title = (doc.user_data.get('title', None)
if hasattr(doc, 'user_data') else None)
return {'text': doc.text, 'ents': ents, 'title': title} return {'text': doc.text, 'ents': ents, 'title': title}

View File

@ -14,13 +14,15 @@ class DependencyRenderer(object):
"""Initialise dependency renderer. """Initialise dependency renderer.
options (dict): Visualiser-specific options (compact, word_spacing, options (dict): Visualiser-specific options (compact, word_spacing,
arrow_spacing, arrow_width, arrow_stroke, distance, arrow_spacing, arrow_width, arrow_stroke, distance, offset_x,
offset_x, color, bg, font) color, bg, font)
""" """
self.compact = options.get('compact', False) self.compact = options.get('compact', False)
self.word_spacing = options.get('word_spacing', 45) self.word_spacing = options.get('word_spacing', 45)
self.arrow_spacing = options.get('arrow_spacing', 12 if self.compact else 20) self.arrow_spacing = options.get('arrow_spacing',
self.arrow_width = options.get('arrow_width', 6 if self.compact else 10) 12 if self.compact else 20)
self.arrow_width = options.get('arrow_width',
6 if self.compact else 10)
self.arrow_stroke = options.get('arrow_stroke', 2) self.arrow_stroke = options.get('arrow_stroke', 2)
self.distance = options.get('distance', 150 if self.compact else 175) self.distance = options.get('distance', 150 if self.compact else 175)
self.offset_x = options.get('offset_x', 50) self.offset_x = options.get('offset_x', 50)
@ -39,7 +41,8 @@ class DependencyRenderer(object):
rendered = [self.render_svg(i, p['words'], p['arcs']) rendered = [self.render_svg(i, p['words'], p['arcs'])
for i, p in enumerate(parsed)] for i, p in enumerate(parsed)]
if page: if page:
content = ''.join([TPL_FIGURE.format(content=svg) for svg in rendered]) content = ''.join([TPL_FIGURE.format(content=svg)
for svg in rendered])
markup = TPL_PAGE.format(content=content) markup = TPL_PAGE.format(content=content)
else: else:
markup = ''.join(rendered) markup = ''.join(rendered)
@ -63,12 +66,13 @@ class DependencyRenderer(object):
self.id = render_id self.id = render_id
words = [self.render_word(w['text'], w['tag'], i) words = [self.render_word(w['text'], w['tag'], i)
for i, w in enumerate(words)] for i, w in enumerate(words)]
arcs = [self.render_arrow(a['label'], a['start'], a['end'], a['dir'], i) arcs = [self.render_arrow(a['label'], a['start'],
a['end'], a['dir'], i)
for i, a in enumerate(arcs)] for i, a in enumerate(arcs)]
content = ''.join(words) + ''.join(arcs) content = ''.join(words) + ''.join(arcs)
return TPL_DEP_SVG.format(id=self.id, width=self.width, height=self.height, return TPL_DEP_SVG.format(id=self.id, width=self.width,
color=self.color, bg=self.bg, font=self.font, height=self.height, color=self.color,
content=content) bg=self.bg, font=self.font, content=content)
def render_word(self, text, tag, i): def render_word(self, text, tag, i):
"""Render individual word. """Render individual word.
@ -96,7 +100,7 @@ class DependencyRenderer(object):
x_start = self.offset_x+start*self.distance+self.arrow_spacing x_start = self.offset_x+start*self.distance+self.arrow_spacing
y = self.offset_y y = self.offset_y
x_end = (self.offset_x+(end-start)*self.distance+start*self.distance x_end = (self.offset_x+(end-start)*self.distance+start*self.distance
-self.arrow_spacing*(self.highest_level-level)/4) - self.arrow_spacing*(self.highest_level-level)/4)
y_curve = self.offset_y-level*self.distance/2 y_curve = self.offset_y-level*self.distance/2
if self.compact: if self.compact:
y_curve = self.offset_y-level*self.distance/6 y_curve = self.offset_y-level*self.distance/6
@ -133,8 +137,10 @@ class DependencyRenderer(object):
if direction is 'left': if direction is 'left':
pos1, pos2, pos3 = (x, x-self.arrow_width+2, x+self.arrow_width-2) pos1, pos2, pos3 = (x, x-self.arrow_width+2, x+self.arrow_width-2)
else: else:
pos1, pos2, pos3 = (end, end+self.arrow_width-2, end-self.arrow_width+2) pos1, pos2, pos3 = (end, end+self.arrow_width-2,
arrowhead = (pos1, y+2, pos2, y-self.arrow_width, pos3, y-self.arrow_width) end-self.arrow_width+2)
arrowhead = (pos1, y+2, pos2, y-self.arrow_width, pos3,
y-self.arrow_width)
return "M{},{} L{},{} {},{}".format(*arrowhead) return "M{},{} L{},{} {},{}".format(*arrowhead)
def get_levels(self, arcs): def get_levels(self, arcs):
@ -159,9 +165,10 @@ class EntityRenderer(object):
""" """
colors = {'ORG': '#7aecec', 'PRODUCT': '#bfeeb7', 'GPE': '#feca74', colors = {'ORG': '#7aecec', 'PRODUCT': '#bfeeb7', 'GPE': '#feca74',
'LOC': '#ff9561', 'PERSON': '#aa9cfc', 'NORP': '#c887fb', 'LOC': '#ff9561', 'PERSON': '#aa9cfc', 'NORP': '#c887fb',
'FACILITY': '#9cc9cc', 'EVENT': '#ffeb80', 'LANGUAGE': '#ff8197', 'FACILITY': '#9cc9cc', 'EVENT': '#ffeb80', 'LAW': '#ff8197',
'WORK_OF_ART': '#f0d0ff', 'DATE': '#bfe1d9', 'TIME': '#bfe1d9', 'LANGUAGE': '#ff8197', 'WORK_OF_ART': '#f0d0ff',
'MONEY': '#e4e7d2', 'QUANTITY': '#e4e7d2', 'ORDINAL': '#e4e7d2', 'DATE': '#bfe1d9', 'TIME': '#bfe1d9', 'MONEY': '#e4e7d2',
'QUANTITY': '#e4e7d2', 'ORDINAL': '#e4e7d2',
'CARDINAL': '#e4e7d2', 'PERCENT': '#e4e7d2'} 'CARDINAL': '#e4e7d2', 'PERCENT': '#e4e7d2'}
colors.update(options.get('colors', {})) colors.update(options.get('colors', {}))
self.default_color = '#ddd' self.default_color = '#ddd'
@ -176,9 +183,11 @@ class EntityRenderer(object):
minify (bool): Minify HTML markup. minify (bool): Minify HTML markup.
RETURNS (unicode): Rendered HTML markup. RETURNS (unicode): Rendered HTML markup.
""" """
rendered = [self.render_ents(p['text'], p['ents'], p.get('title', None)) for p in parsed] rendered = [self.render_ents(p['text'], p['ents'],
p.get('title', None)) for p in parsed]
if page: if page:
docs = ''.join([TPL_FIGURE.format(content=doc) for doc in rendered]) docs = ''.join([TPL_FIGURE.format(content=doc)
for doc in rendered])
markup = TPL_PAGE.format(content=docs) markup = TPL_PAGE.format(content=docs)
else: else:
markup = ''.join(rendered) markup = ''.join(rendered)