Update spacy evaluate and add displaCy option

This commit is contained in:
ines 2017-10-04 00:03:15 +02:00
parent f24c2e3a8a
commit 73ac0aa0b5

View File

@ -33,16 +33,23 @@ numpy.random.seed(0)
data_path=("Location of JSON-formatted evaluation data", "positional", None, str), data_path=("Location of JSON-formatted evaluation data", "positional", None, str),
gold_preproc=("Use gold preprocessing", "flag", "G", bool), gold_preproc=("Use gold preprocessing", "flag", "G", bool),
gpu_id=("Use GPU", "option", "g", int), gpu_id=("Use GPU", "option", "g", int),
displacy_path=("Directory to output rendered parses as HTML", "option", "dp", str),
displacy_limit=("Limit of parses to render as HTML", "option", "dl", int)
) )
def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False): def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False,
displacy_path=None, displacy_limit=25):
""" """
Train a model. Expects data in spaCy's JSON format. Evaluate a model. To render a sample of parses in a HTML file, set an output
directory as the displacy_path argument.
""" """
util.use_gpu(gpu_id) util.use_gpu(gpu_id)
util.set_env_log(False) util.set_env_log(False)
data_path = util.ensure_path(data_path) data_path = util.ensure_path(data_path)
displacy_path = util.ensure_path(displacy_path)
if not data_path.exists(): if not data_path.exists():
prints(data_path, title="Evaluation data not found", exits=1) prints(data_path, title="Evaluation data not found", exits=1)
if displacy_path and not displacy_path.exists():
prints(displacy_path, title="Visualization output directory not found", exits=1)
corpus = GoldCorpus(data_path, data_path) corpus = GoldCorpus(data_path, data_path)
nlp = util.load_model(model) nlp = util.load_model(model)
dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc)) dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
@ -50,18 +57,27 @@ def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False):
scorer = nlp.evaluate(dev_docs, verbose=False) scorer = nlp.evaluate(dev_docs, verbose=False)
end = timer() end = timer()
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs) nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
print('Time', end-begin, 'words', nwords, 'w.p.s', nwords/(end-begin)) print_results(scorer, time=end - begin, words=nwords,
print_results(scorer) wps=nwords / (end - begin))
if displacy_path:
docs, golds = zip(*dev_docs)
render_deps = 'parser' in nlp.meta.get('pipeline', [])
render_ents = 'ner' in nlp.meta.get('pipeline', [])
render_parses(docs, displacy_path, model_name=model, limit=displacy_limit,
deps=render_deps, ents=render_ents)
prints(displacy_path, title="Generated %s parses as HTML" % displacy_limit)
def _render_parses(i, to_render): def render_parses(docs, output_path, model_name='', limit=250, deps=True, ents=True):
to_render[0].user_data['title'] = "Batch %d" % i docs[0].user_data['title'] = model_name
with Path('/tmp/entities.html').open('w') as file_: if ents:
html = displacy.render(to_render[:5], style='ent', page=True) with (output_path / 'entities.html').open('w') as file_:
file_.write(html) html = displacy.render(docs[:limit], style='ent', page=True)
with Path('/tmp/parses.html').open('w') as file_: file_.write(html)
html = displacy.render(to_render[:5], style='dep', page=True) if deps:
file_.write(html) with (output_path / 'parses.html').open('w') as file_:
html = displacy.render(docs[:limit], style='dep', page=True, options={'compact': True})
file_.write(html)
def print_progress(itn, losses, dev_scores, wps=0.0): def print_progress(itn, losses, dev_scores, wps=0.0):
@ -88,8 +104,11 @@ def print_progress(itn, losses, dev_scores, wps=0.0):
print(tpl.format(itn, **scores)) print(tpl.format(itn, **scores))
def print_results(scorer): def print_results(scorer, time, words, wps):
results = { results = {
'Time': '%.2f s' % time,
'Words': words,
'Words/s': '%.0f' % wps,
'TOK': '%.2f' % scorer.token_acc, 'TOK': '%.2f' % scorer.token_acc,
'POS': '%.2f' % scorer.tags_acc, 'POS': '%.2f' % scorer.tags_acc,
'UAS': '%.2f' % scorer.uas, 'UAS': '%.2f' % scorer.uas,