mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Update spacy evaluate and add displaCy option
This commit is contained in:
parent
f24c2e3a8a
commit
73ac0aa0b5
|
@ -33,16 +33,23 @@ numpy.random.seed(0)
|
||||||
data_path=("Location of JSON-formatted evaluation data", "positional", None, str),
|
data_path=("Location of JSON-formatted evaluation data", "positional", None, str),
|
||||||
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||||
gpu_id=("Use GPU", "option", "g", int),
|
gpu_id=("Use GPU", "option", "g", int),
|
||||||
|
displacy_path=("Directory to output rendered parses as HTML", "option", "dp", str),
|
||||||
|
displacy_limit=("Limit of parses to render as HTML", "option", "dl", int)
|
||||||
)
|
)
|
||||||
def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False):
|
def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False,
|
||||||
|
displacy_path=None, displacy_limit=25):
|
||||||
"""
|
"""
|
||||||
Train a model. Expects data in spaCy's JSON format.
|
Evaluate a model. To render a sample of parses in a HTML file, set an output
|
||||||
|
directory as the displacy_path argument.
|
||||||
"""
|
"""
|
||||||
util.use_gpu(gpu_id)
|
util.use_gpu(gpu_id)
|
||||||
util.set_env_log(False)
|
util.set_env_log(False)
|
||||||
data_path = util.ensure_path(data_path)
|
data_path = util.ensure_path(data_path)
|
||||||
|
displacy_path = util.ensure_path(displacy_path)
|
||||||
if not data_path.exists():
|
if not data_path.exists():
|
||||||
prints(data_path, title="Evaluation data not found", exits=1)
|
prints(data_path, title="Evaluation data not found", exits=1)
|
||||||
|
if displacy_path and not displacy_path.exists():
|
||||||
|
prints(displacy_path, title="Visualization output directory not found", exits=1)
|
||||||
corpus = GoldCorpus(data_path, data_path)
|
corpus = GoldCorpus(data_path, data_path)
|
||||||
nlp = util.load_model(model)
|
nlp = util.load_model(model)
|
||||||
dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
|
dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
|
||||||
|
@ -50,17 +57,26 @@ def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False):
|
||||||
scorer = nlp.evaluate(dev_docs, verbose=False)
|
scorer = nlp.evaluate(dev_docs, verbose=False)
|
||||||
end = timer()
|
end = timer()
|
||||||
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
|
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
|
||||||
print('Time', end-begin, 'words', nwords, 'w.p.s', nwords/(end-begin))
|
print_results(scorer, time=end - begin, words=nwords,
|
||||||
print_results(scorer)
|
wps=nwords / (end - begin))
|
||||||
|
if displacy_path:
|
||||||
|
docs, golds = zip(*dev_docs)
|
||||||
|
render_deps = 'parser' in nlp.meta.get('pipeline', [])
|
||||||
|
render_ents = 'ner' in nlp.meta.get('pipeline', [])
|
||||||
|
render_parses(docs, displacy_path, model_name=model, limit=displacy_limit,
|
||||||
|
deps=render_deps, ents=render_ents)
|
||||||
|
prints(displacy_path, title="Generated %s parses as HTML" % displacy_limit)
|
||||||
|
|
||||||
|
|
||||||
def _render_parses(i, to_render):
|
def render_parses(docs, output_path, model_name='', limit=250, deps=True, ents=True):
|
||||||
to_render[0].user_data['title'] = "Batch %d" % i
|
docs[0].user_data['title'] = model_name
|
||||||
with Path('/tmp/entities.html').open('w') as file_:
|
if ents:
|
||||||
html = displacy.render(to_render[:5], style='ent', page=True)
|
with (output_path / 'entities.html').open('w') as file_:
|
||||||
|
html = displacy.render(docs[:limit], style='ent', page=True)
|
||||||
file_.write(html)
|
file_.write(html)
|
||||||
with Path('/tmp/parses.html').open('w') as file_:
|
if deps:
|
||||||
html = displacy.render(to_render[:5], style='dep', page=True)
|
with (output_path / 'parses.html').open('w') as file_:
|
||||||
|
html = displacy.render(docs[:limit], style='dep', page=True, options={'compact': True})
|
||||||
file_.write(html)
|
file_.write(html)
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,8 +104,11 @@ def print_progress(itn, losses, dev_scores, wps=0.0):
|
||||||
print(tpl.format(itn, **scores))
|
print(tpl.format(itn, **scores))
|
||||||
|
|
||||||
|
|
||||||
def print_results(scorer):
|
def print_results(scorer, time, words, wps):
|
||||||
results = {
|
results = {
|
||||||
|
'Time': '%.2f s' % time,
|
||||||
|
'Words': words,
|
||||||
|
'Words/s': '%.0f' % wps,
|
||||||
'TOK': '%.2f' % scorer.token_acc,
|
'TOK': '%.2f' % scorer.token_acc,
|
||||||
'POS': '%.2f' % scorer.tags_acc,
|
'POS': '%.2f' % scorer.tags_acc,
|
||||||
'UAS': '%.2f' % scorer.uas,
|
'UAS': '%.2f' % scorer.uas,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user