mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Add spacy evaluate
This commit is contained in:
parent
cdb2d83e16
commit
69c7c642c2
|
@ -7,7 +7,7 @@ if __name__ == '__main__':
|
||||||
import plac
|
import plac
|
||||||
import sys
|
import sys
|
||||||
from spacy.cli import download, link, info, package, train, convert, model
|
from spacy.cli import download, link, info, package, train, convert, model
|
||||||
from spacy.cli import profile
|
from spacy.cli import profile, evaluate
|
||||||
from spacy.util import prints
|
from spacy.util import prints
|
||||||
|
|
||||||
commands = {
|
commands = {
|
||||||
|
@ -15,6 +15,7 @@ if __name__ == '__main__':
|
||||||
'link': link,
|
'link': link,
|
||||||
'info': info,
|
'info': info,
|
||||||
'train': train,
|
'train': train,
|
||||||
|
'evaluate': evaluate,
|
||||||
'convert': convert,
|
'convert': convert,
|
||||||
'package': package,
|
'package': package,
|
||||||
'model': model,
|
'model': model,
|
||||||
|
|
|
@ -4,5 +4,6 @@ from .link import link
|
||||||
from .package import package
|
from .package import package
|
||||||
from .profile import profile
|
from .profile import profile
|
||||||
from .train import train
|
from .train import train
|
||||||
|
from .evaluate import evaluate
|
||||||
from .convert import convert
|
from .convert import convert
|
||||||
from .model import model
|
from .model import model
|
||||||
|
|
93
spacy/cli/evaluate.py
Normal file
93
spacy/cli/evaluate.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals, division, print_function
|
||||||
|
|
||||||
|
import plac
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
import cytoolz
|
||||||
|
from pathlib import Path
|
||||||
|
import dill
|
||||||
|
import tqdm
|
||||||
|
from thinc.neural._classes.model import Model
|
||||||
|
from thinc.neural.optimizers import linear_decay
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
import random
|
||||||
|
import numpy.random
|
||||||
|
|
||||||
|
from ..tokens.doc import Doc
|
||||||
|
from ..scorer import Scorer
|
||||||
|
from ..gold import GoldParse, merge_sents
|
||||||
|
from ..gold import GoldCorpus, minibatch
|
||||||
|
from ..util import prints
|
||||||
|
from .. import util
|
||||||
|
from .. import about
|
||||||
|
from .. import displacy
|
||||||
|
from ..compat import json_dumps
|
||||||
|
|
||||||
|
random.seed(0)
|
||||||
|
numpy.random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
model=("Model name or path", "positional", None, str),
|
||||||
|
data_path=("Location of JSON-formatted evaluation data", "positional", None, str),
|
||||||
|
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||||
|
)
|
||||||
|
def evaluate(cmd, model, data_path, gold_preproc=False):
|
||||||
|
"""
|
||||||
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
|
"""
|
||||||
|
util.set_env_log(True)
|
||||||
|
data_path = util.ensure_path(data_path)
|
||||||
|
if not data_path.exists():
|
||||||
|
prints(data_path, title="Evaluation data not found", exits=1)
|
||||||
|
corpus = GoldCorpus(data_path, data_path)
|
||||||
|
nlp = util.load_model(model)
|
||||||
|
scorer = nlp.evaluate(list(corpus.dev_docs(nlp, gold_preproc=gold_preproc)))
|
||||||
|
print_results(scorer)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_parses(i, to_render):
|
||||||
|
to_render[0].user_data['title'] = "Batch %d" % i
|
||||||
|
with Path('/tmp/entities.html').open('w') as file_:
|
||||||
|
html = displacy.render(to_render[:5], style='ent', page=True)
|
||||||
|
file_.write(html)
|
||||||
|
with Path('/tmp/parses.html').open('w') as file_:
|
||||||
|
html = displacy.render(to_render[:5], style='dep', page=True)
|
||||||
|
file_.write(html)
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress(itn, losses, dev_scores, wps=0.0):
|
||||||
|
scores = {}
|
||||||
|
for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
|
||||||
|
'ents_p', 'ents_r', 'ents_f', 'wps']:
|
||||||
|
scores[col] = 0.0
|
||||||
|
scores['dep_loss'] = losses.get('parser', 0.0)
|
||||||
|
scores['ner_loss'] = losses.get('ner', 0.0)
|
||||||
|
scores['tag_loss'] = losses.get('tagger', 0.0)
|
||||||
|
scores.update(dev_scores)
|
||||||
|
scores['wps'] = wps
|
||||||
|
tpl = '\t'.join((
|
||||||
|
'{:d}',
|
||||||
|
'{dep_loss:.3f}',
|
||||||
|
'{ner_loss:.3f}',
|
||||||
|
'{uas:.3f}',
|
||||||
|
'{ents_p:.3f}',
|
||||||
|
'{ents_r:.3f}',
|
||||||
|
'{ents_f:.3f}',
|
||||||
|
'{tags_acc:.3f}',
|
||||||
|
'{token_acc:.3f}',
|
||||||
|
'{wps:.1f}'))
|
||||||
|
print(tpl.format(itn, **scores))
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(scorer):
|
||||||
|
results = {
|
||||||
|
'TOK': '%.2f' % scorer.token_acc,
|
||||||
|
'POS': '%.2f' % scorer.tags_acc,
|
||||||
|
'UAS': '%.2f' % scorer.uas,
|
||||||
|
'LAS': '%.2f' % scorer.las,
|
||||||
|
'NER P': '%.2f' % scorer.ents_p,
|
||||||
|
'NER R': '%.2f' % scorer.ents_r,
|
||||||
|
'NER F': '%.2f' % scorer.ents_f}
|
||||||
|
util.print_table(results, title="Results")
|
Loading…
Reference in New Issue
Block a user