From c31792aaec68a67ddaac9ef7ea812345eee0a6bf Mon Sep 17 00:00:00 2001 From: ines Date: Sun, 14 May 2017 17:50:23 +0200 Subject: [PATCH] Add displaCy visualisers (see #1058) --- spacy/displacy/__init__.py | 99 +++++++++++++++++ spacy/displacy/render.py | 216 ++++++++++++++++++++++++++++++++++++ spacy/displacy/templates.py | 60 ++++++++++ spacy/util.py | 10 ++ 4 files changed, 385 insertions(+) create mode 100644 spacy/displacy/__init__.py create mode 100644 spacy/displacy/render.py create mode 100644 spacy/displacy/templates.py diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py new file mode 100644 index 000000000..f5216ccfc --- /dev/null +++ b/spacy/displacy/__init__.py @@ -0,0 +1,99 @@ +# coding: utf8 +from __future__ import unicode_literals + +from .render import DependencyRenderer, EntityRenderer +from ..tokens import Doc +from ..util import prints + + +_html = {} + + +def render(docs, style='dep', page=False, minify=False, jupyter=False, options={}): + """Render displaCy visualisation. + + docs (list or Doc): Document(s) to visualise. + style (unicode): Visualisation style, 'dep' or 'ent'. + page (bool): Render markup as full HTML page. + minify (bool): Minify HTML markup. + jupyter (bool): Experimental, use Jupyter's display() to output markup. + options (dict): Visualiser-specific options, e.g. colors. + RETURNS: Rendered HTML markup. + """ + if isinstance(docs, Doc): + docs = [docs] + if style is 'dep': + renderer = DependencyRenderer(options=options) + parsed = [parse_deps(doc, options) for doc in docs] + elif style is 'ent': + renderer = EntityRenderer(options=options) + parsed = [parse_ents(doc, options) for doc in docs] + _html['parsed'] = renderer.render(parsed, page=page, minify=minify) + return _html['parsed'] + + +def serve(docs, style='dep', page=True, minify=False, options={}, port=5000): + """Serve displaCy visualisation. + + docs (list or Doc): Document(s) to visualise. + style (unicode): Visualisation style, 'dep' or 'ent'. 
+ page (bool): Render markup as full HTML page. + minify (bool): Minify HTML markup. + options (dict): Visualiser-specific options, e.g. colors. + port (int): Port to serve visualisation. + """ + from wsgiref import simple_server + render(docs, style=style, page=page, minify=minify, options=options) + httpd = simple_server.make_server('0.0.0.0', port, app) + prints("Using the '%s' visualizer" % style, title="Serving on port %d..." % port) + httpd.serve_forever() + + +def app(environ, start_response): + start_response('200 OK', [('Content-type', 'text/html; charset=utf-8')]) + res = _html['parsed'].encode(encoding='utf-8') + return [res] + + +def parse_deps(doc, options={}): + """Generate dependency parse in {'words': [], 'arcs': []} format. + + doc (Doc): Document to parse. + RETURNS (dict): Generated dependency parse keyed by words and arcs. + """ + if options.get('collapse_punct', True): + spans = [] + for word in doc[:-1]: + if word.is_punct or not word.nbor(1).is_punct: + continue + start = word.i + end = word.i + 1 + while end < len(doc) and doc[end].is_punct: + end += 1 + span = doc[start : end] + spans.append((span.start_char, span.end_char, word.tag_, + word.lemma_, word.ent_type_)) + for span_props in spans: + doc.merge(*span_props) + words = [{'text': w.text, 'tag': w.tag_} for w in doc] + arcs = [] + for word in doc: + if word.i < word.head.i: + arcs.append({'start': word.i, 'end': word.head.i, + 'label': word.dep_, 'dir': 'left'}) + elif word.i > word.head.i: + arcs.append({'start': word.head.i, 'end': word.i, + 'label': word.dep_, 'dir': 'right'}) + return {'words': words, 'arcs': arcs} + + +def parse_ents(doc, options={}): + """Generate named entities in [{start: i, end: i, label: 'label'}] format. + + doc (Doc): Document to parse. + RETURNS (dict): Generated entities keyed by text (original text) and ents. 
+ """ + ents = [{'start': ent.start_char, 'end': ent.end_char, 'label': ent.label_} + for ent in doc.ents] + title = doc.user_data.get('title', None) + return {'text': doc.text, 'ents': ents, 'title': title} diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py new file mode 100644 index 000000000..dfcaa8be1 --- /dev/null +++ b/spacy/displacy/render.py @@ -0,0 +1,216 @@ +# coding: utf8 +from __future__ import unicode_literals + +from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS +from .templates import TPL_ENT, TPL_ENTS, TPL_FIGURE, TPL_TITLE, TPL_PAGE +from ..util import minify_html + + +class DependencyRenderer(object): + """Render dependency parses as SVGs.""" + style = 'dep' + + def __init__(self, options={}): + """Initialise dependency renderer. + + options (dict): Visualiser-specific options (compact, word_spacing, + arrow_spacing, arrow_width, arrow_stroke, distance, + offset_x, color, bg, font) + """ + self.compact = options.get('compact', False) + distance, arrow_width = (85, 8) if self.compact else (175, 10) + self.word_spacing = options.get('word_spacing', 45) + self.arrow_spacing = options.get('arrow_spacing', 20) + self.arrow_width = options.get('arrow_width', arrow_width) + self.arrow_stroke = options.get('arrow_stroke', 2) + self.distance = options.get('distance', distance) + self.offset_x = options.get('offset_x', 50) + self.color = options.get('color', '#000000') + self.bg = options.get('bg', '#ffffff') + self.font = options.get('font', 'Arial') + + def render(self, parsed, page=False, minify=False): + """Render complete markup. + + parsed (list): Dependency parses to render. + page (bool): Render parses wrapped as full HTML page. + minify (bool): Minify HTML markup. + RETURNS (unicode): Rendered SVG or HTML markup. 
+ """ + rendered = [self.render_svg(i, p['words'], p['arcs']) + for i, p in enumerate(parsed)] + if page: + content = ''.join([TPL_FIGURE.format(content=svg) for svg in rendered]) + markup = TPL_PAGE.format(content=content) + else: + markup = ''.join(rendered) + if minify: + return minify_html(markup) + return markup + + def render_svg(self, render_id, words, arcs): + """Render SVG. + + render_id (int): Unique ID, typically index of document. + words (list): Individual words and their tags. + arcs (list): Individual arcs and their start, end, direction and label. + RETURNS (unicode): Rendered SVG markup. + """ + self.levels = self.get_levels(arcs) + self.highest_level = len(self.levels) + self.offset_y = self.distance/2*self.highest_level+self.arrow_stroke + self.width = self.offset_x+len(words)*self.distance + self.height = self.offset_y+3*self.word_spacing + self.id = render_id + words = [self.render_word(w['text'], w['tag'], i) + for i, w in enumerate(words)] + arcs = [self.render_arrow(a['label'], a['start'], a['end'], a['dir'], i) + for i, a in enumerate(arcs)] + content = ''.join(words) + ''.join(arcs) + return TPL_DEP_SVG.format(id=self.id, width=self.width, height=self.height, + color=self.color, bg=self.bg, font=self.font, + content=content) + + def render_word(self, text, tag, i): + """Render individual word. + + text (unicode): Word text. + tag (unicode): Part-of-speech tag. + i (int): Unique ID, typically word index. + RETURNS (unicode): Rendered SVG markup. + """ + y = self.offset_y+self.word_spacing + x = self.offset_x+i*self.distance + return TPL_DEP_WORDS.format(text=text, tag=tag, x=x, y=y) + + def render_arrow(self, label, start, end, direction, i): + """Render individual arrow. + + label (unicode): Dependency label. + start (int): Index of start word. + end (int): Index of end word. + direction (unicode): Arrow direction, 'left' or 'right'. + i (int): Unique ID, typically arrow index. + RETURNS (unicode): Rendered SVG markup. 
+ """ + level = self.levels.index(end-start)+1 + x_start = self.offset_x+start*self.distance+self.arrow_spacing + y = self.offset_y + x_end = (self.offset_x+(end-start)*self.distance+start*self.distance + -self.arrow_spacing*(self.highest_level-level)/4) + y_curve = self.offset_y-level*self.distance/2 + if y_curve == 0 and len(self.levels) > 5: + y_curve = -self.distance + arrowhead = self.get_arrowhead(direction, x_start, y, x_end) + arc = self.get_arc(x_start, y, y_curve, x_end) + return TPL_DEP_ARCS.format(id=self.id, i=i, stroke=self.arrow_stroke, + head=arrowhead, label=label, arc=arc) + + def get_arc(self, x_start, y, y_curve, x_end): + """Render individual arc. + + x_start (int): X-coordinate of arrow start point. + y (int): Y-coordinate of arrow start and end point. + y_curve (int): Y-coordinate of Cubic Bézier y_curve point. + x_end (int): X-coordinate of arrow end point. + RETURNS (unicode): Definition of the arc path ('d' attribute). + """ + template = "M{x},{y} C{x},{c} {e},{c} {e},{y}" + if self.compact: + template = "M{x},{y} {x},{c} {e},{c} {e},{y}" + return template.format(x=x_start, y=y, c=y_curve, e=x_end) + + def get_arrowhead(self, direction, x, y, end): + """Render individual arrow head. + + direction (unicode): Arrow direction, 'left' or 'right'. + x (int): X-coordinate of arrow start point. + y (int): Y-coordinate of arrow start and end point. + end (int): X-coordinate of arrow end point. + RETURNS (unicode): Definition of the arrow head path ('d' attribute). + """ + if direction is 'left': + pos1, pos2, pos3 = (x, x-self.arrow_width+2, x+self.arrow_width-2) + else: + pos1, pos2, pos3 = (end, end+self.arrow_width-2, end-self.arrow_width+2) + arrowhead = (pos1, y+2, pos2, y-self.arrow_width, pos3, y-self.arrow_width) + return "M{},{} L{},{} {},{}".format(*arrowhead) + + def get_levels(self, arcs): + """Calculate available arc height "levels". + Used to calculate arrow heights dynamically and without wasting space. 
+ + arcs (list): Individual arcs and their start, end, direction and label. + RETURNS (list): Arc levels sorted from lowest to highest. + """ + levels = set(map(lambda arc: arc['end'] - arc['start'], arcs)) + return sorted(list(levels)) + + +class EntityRenderer(object): + """Render named entities as HTML.""" + style = 'ent' + + def __init__(self, options={}): + """Initialise entity renderer. + + options (dict): Visualiser-specific options (colors, ents) + """ + colors = {'org': '#7aecec', 'product': '#bfeeb7', 'gpe': '#feca74', + 'loc': '#ff9561', 'person': '#9886fc', 'norp': '#c887fb', + 'facility': '#9cc9cc', 'event': '#ffeb80', 'language': '#ff8197', + 'work_of_art': '#f0d0ff', 'date': '#bfe1d9', 'time': '#bfe1d9', + 'money': '#e4e7d2', 'quantity': '#e4e7d2', 'ordinal': '#e4e7d2', + 'cardinal': '#e4e7d2', 'percent': '#e4e7d2'} + colors.update(options.get('colors', {})) + self.default_color = '#ddd' + self.colors = colors + self.ents = options.get('ents', None) + + def render(self, parsed, page=False, minify=False): + """Render complete markup. + + parsed (list): Parsed entity data to render. + page (bool): Render parses wrapped as full HTML page. + minify (bool): Minify HTML markup. + RETURNS (unicode): Rendered HTML markup. + """ + rendered = [self.render_ents(p['text'], p['ents'], p['title']) for p in parsed] + if page: + docs = ''.join([TPL_FIGURE.format(content=doc) for doc in rendered]) + markup = TPL_PAGE.format(content=docs) + else: + markup = ''.join(rendered) + if minify: + return minify_html(markup) + return markup + + def render_ents(self, text, spans, title): + """Render entities in text. + + text (unicode): Original text. + spans (list): Individual entity spans and their start, end and label. 
+ """ + markup = '' + offset = 0 + for span in spans: + label = span['label'] + start = span['start'] + end = span['end'] + entity = text[start:end] + fragments = text[offset:start].split('\n') + for i, fragment in enumerate(fragments): + markup += fragment + if len(fragments) > 1 and i != len(fragments)-1: + markup += '
' + if self.ents is None or label.lower() in self.ents: + color = self.colors.get(label.lower(), self.default_color) + markup += TPL_ENT.format(label=label, text=entity, bg=color) + else: + markup += entity + offset = end + markup += text[offset:] + markup = TPL_ENTS.format(content=markup, colors=self.colors) + if title: + markup = TPL_TITLE.format(title=title) + markup + return markup diff --git a/spacy/displacy/templates.py b/spacy/displacy/templates.py new file mode 100644 index 000000000..5cbcba98d --- /dev/null +++ b/spacy/displacy/templates.py @@ -0,0 +1,60 @@ +# coding: utf8 +from __future__ import unicode_literals + + +TPL_DEP_SVG = """ +{content} +""" + + +TPL_DEP_WORDS = """ + + {text} + {tag} + +""" + + +TPL_DEP_ARCS = """ + + + + {label} + + + +""" + + +TPL_FIGURE = """ +
{content}
+""" + +TPL_TITLE = """ +

{title}

+""" + + +TPL_ENTS = """ +
{content}
+""" + + +TPL_ENT = """ + + {text} + {label} + +""" + + +TPL_PAGE = """ + + + + displaCy + + + {content} + +""" diff --git a/spacy/util.py b/spacy/util.py index db7be7e69..c24c191ba 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -338,3 +338,13 @@ def _wrap(text, wrap_max=80, indent=4): return textwrap.fill(text, width=wrap_width, initial_indent=indent, subsequent_indent=indent, break_long_words=False, break_on_hyphens=False) + + +def minify_html(html): + """Perform a template-specific, rudimentary HTML minification for displaCy. + Disclaimer: NOT a general-purpose solution, only removes indentation/newlines. + + html (unicode): Markup to minify. + RETURNS (unicode): "Minified" HTML. + """ + return html.strip().replace(' ', '').replace('\n', '')