Accept Span to displacy render (#2478) (closes #2477)

* Add Span to displacy render

* Fix span support, errors and add tests
This commit is contained in:
Ole Henrik Skogstrøm 2018-06-25 14:55:16 +02:00 committed by Ines Montani
parent f33c703066
commit d16cb6bee6
3 changed files with 23 additions and 5 deletions

View File

@ -2,7 +2,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .render import DependencyRenderer, EntityRenderer from .render import DependencyRenderer, EntityRenderer
from ..tokens import Doc from ..tokens import Doc, Span
from ..compat import b_to_str from ..compat import b_to_str
from ..errors import Errors, Warnings, user_warning from ..errors import Errors, Warnings, user_warning
from ..util import prints, is_in_jupyter from ..util import prints, is_in_jupyter
@ -29,8 +29,11 @@ def render(docs, style='dep', page=False, minify=False, jupyter=IS_JUPYTER,
'ent': (EntityRenderer, parse_ents)} 'ent': (EntityRenderer, parse_ents)}
if style not in factories: if style not in factories:
raise ValueError(Errors.E087.format(style=style)) raise ValueError(Errors.E087.format(style=style))
if isinstance(docs, Doc) or isinstance(docs, dict): if isinstance(docs, (Doc, Span, dict)):
docs = [docs] docs = [docs]
docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs]
if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs):
raise ValueError(Errors.E096)
renderer, converter = factories[style] renderer, converter = factories[style]
renderer = renderer(options=options) renderer = renderer(options=options)
parsed = [converter(doc, options) for doc in docs] if not manual else docs parsed = [converter(doc, options) for doc in docs] if not manual else docs

View File

@ -247,6 +247,8 @@ class Errors(object):
E094 = ("Error reading line {line_num} in vectors file {loc}.") E094 = ("Error reading line {line_num} in vectors file {loc}.")
E095 = ("Can't write to frozen dictionary. This is likely an internal " E095 = ("Can't write to frozen dictionary. This is likely an internal "
"error. Are you writing to a default function argument?") "error. Are you writing to a default function argument?")
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
"Span objects, or dicts if set to manual=True.")
@add_codes @add_codes

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from ..util import ensure_path from ..util import ensure_path
from .. import util from .. import util
from ..displacy import parse_deps, parse_ents from .. import displacy
from ..tokens import Span from ..tokens import Span
from .util import get_doc from .util import get_doc
from .._ml import PrecomputableAffine from .._ml import PrecomputableAffine
@ -38,7 +38,7 @@ def test_displacy_parse_ents(en_vocab):
"""Test that named entities on a Doc are converted into displaCy's format.""" """Test that named entities on a Doc are converted into displaCy's format."""
doc = get_doc(en_vocab, words=["But", "Google", "is", "starting", "from", "behind"]) doc = get_doc(en_vocab, words=["But", "Google", "is", "starting", "from", "behind"])
doc.ents = [Span(doc, 1, 2, label=doc.vocab.strings[u'ORG'])] doc.ents = [Span(doc, 1, 2, label=doc.vocab.strings[u'ORG'])]
ents = parse_ents(doc) ents = displacy.parse_ents(doc)
assert isinstance(ents, dict) assert isinstance(ents, dict)
assert ents['text'] == 'But Google is starting from behind ' assert ents['text'] == 'But Google is starting from behind '
assert ents['ents'] == [{'start': 4, 'end': 10, 'label': 'ORG'}] assert ents['ents'] == [{'start': 4, 'end': 10, 'label': 'ORG'}]
@ -53,7 +53,7 @@ def test_displacy_parse_deps(en_vocab):
deps = ['nsubj', 'ROOT', 'det', 'attr'] deps = ['nsubj', 'ROOT', 'det', 'attr']
doc = get_doc(en_vocab, words=words, heads=heads, pos=pos, tags=tags, doc = get_doc(en_vocab, words=words, heads=heads, pos=pos, tags=tags,
deps=deps) deps=deps)
deps = parse_deps(doc) deps = displacy.parse_deps(doc)
assert isinstance(deps, dict) assert isinstance(deps, dict)
assert deps['words'] == [{'text': 'This', 'tag': 'DET'}, assert deps['words'] == [{'text': 'This', 'tag': 'DET'},
{'text': 'is', 'tag': 'VERB'}, {'text': 'is', 'tag': 'VERB'},
@ -64,6 +64,19 @@ def test_displacy_parse_deps(en_vocab):
{'start': 1, 'end': 3, 'label': 'attr', 'dir': 'right'}] {'start': 1, 'end': 3, 'label': 'attr', 'dir': 'right'}]
def test_displacy_spans(en_vocab):
"""Test that displaCy can render Spans."""
doc = get_doc(en_vocab, words=["But", "Google", "is", "starting", "from", "behind"])
doc.ents = [Span(doc, 1, 2, label=doc.vocab.strings[u'ORG'])]
html = displacy.render(doc[1:4], style='ent')
assert html.startswith('<div')
def test_displacy_raises_for_wrong_type(en_vocab):
with pytest.raises(ValueError):
html = displacy.render('hello world')
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2): def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP) model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP)
assert model.W.shape == (nF, nO, nP, nI) assert model.W.shape == (nF, nO, nP, nI)