2017-03-23 13:08:41 +03:00
|
|
|
# coding: utf8
|
|
|
|
from __future__ import unicode_literals, division, print_function
|
|
|
|
|
|
|
|
import json
|
2017-04-23 23:27:10 +03:00
|
|
|
from collections import defaultdict
|
2017-05-16 17:17:30 +03:00
|
|
|
import cytoolz
|
2017-05-17 13:04:50 +03:00
|
|
|
from pathlib import Path
|
|
|
|
import dill
|
2017-03-23 13:08:41 +03:00
|
|
|
|
2017-05-17 13:04:50 +03:00
|
|
|
from ..tokens.doc import Doc
|
2017-03-23 13:08:41 +03:00
|
|
|
from ..scorer import Scorer
|
|
|
|
from ..gold import GoldParse, merge_sents
|
|
|
|
from ..gold import read_json_file as read_gold_json
|
2017-05-08 00:25:29 +03:00
|
|
|
from ..util import prints
|
2017-03-23 13:08:41 +03:00
|
|
|
from .. import util
|
2017-05-17 13:04:50 +03:00
|
|
|
from .. import displacy
|
2017-03-23 13:08:41 +03:00
|
|
|
|
|
|
|
|
2017-05-17 13:04:50 +03:00
|
|
|
def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
          tagger, parser, ner, parser_L1):
    """Train a pipeline from JSON gold data and optionally report dev scores.

    language: Language code resolved with ``util.get_lang_class``.
    output_dir: Directory that must already exist; it is passed to
        ``train_model`` (which saves the model there) and to ``evaluate``.
    train_data: Path to the training data, read with ``read_gold_json``.
    dev_data: Optional path to development data. When given, the trained
        model is evaluated on it and the results table is printed.
    n_iter: Number of training epochs, forwarded to ``train_model``.
    n_sents: Passed as ``limit`` when reading both train and dev data.
    tagger, parser, ner, parser_L1: Accepted for interface compatibility but
        NOT currently used; no component settings reach ``train_model``.

    Exits via ``prints(..., exits=True)`` if a required path is missing.
    """
    output_path = util.ensure_path(output_dir)
    train_path = util.ensure_path(train_data)
    dev_path = util.ensure_path(dev_data)
    if not output_path.exists():
        prints(output_path, title="Output directory not found", exits=True)
    if not train_path.exists():
        prints(train_path, title="Training data not found", exits=True)
    # dev data is optional, so only validate it when a path was supplied.
    if dev_path and not dev_path.exists():
        prints(dev_path, title="Development data not found", exits=True)

    lang = util.get_lang_class(language)
    # NOTE(review): per-component configs (pseudo-projective parsing, parser
    # L1, feature templates) used to be assembled here but were never passed
    # anywhere, so they were removed. Wire them into train_model(**cfg) once
    # begin_training accepts them.
    gold_train = list(read_gold_json(train_path, limit=n_sents))
    gold_dev = list(read_gold_json(dev_path, limit=n_sents)) if dev_path else None

    train_model(lang, gold_train, gold_dev, output_path, n_iter)
    if gold_dev:
        scorer = evaluate(lang, gold_dev, output_path)
        print_results(scorer)
|
2017-03-23 13:08:41 +03:00
|
|
|
|
|
|
|
|
|
|
|
def train_config(config):
    """Load a JSON training config and report any missing required settings.

    config: Path (or path-like string) of the JSON config file.

    Returns the parsed config object. Exits via ``prints(..., exits=True)``
    if the file does not exist.
    """
    config_path = util.ensure_path(config)
    if not config_path.is_file():
        prints(config_path, title="Config file not found", exits=True)
    # json.load expects a readable file object, not a path; passing the Path
    # directly raised AttributeError. Decode explicitly as UTF-8 (the JSON
    # default) instead of relying on the platform locale.
    with config_path.open('r', encoding='utf8') as file_:
        config = json.load(file_)
    # TODO: populate with the settings a config must define; none are
    # validated yet.
    required_settings = []
    for setting in required_settings:
        if setting not in config:
            prints("%s not found in config file." % setting, title="Missing setting")
    return config
|
2017-03-23 13:08:41 +03:00
|
|
|
|
|
|
|
|
2017-05-16 12:21:59 +03:00
|
|
|
# How many of the most recently parsed batch documents the /tmp preview pages show.
_N_PREVIEWS = 5


def _write_previews(docs):
    """Render ``docs`` with displaCy and overwrite the /tmp preview pages.

    The entity view goes to /tmp/entities.html and the dependency view to
    /tmp/parses.html. HTML is rendered before each file is opened, so a
    render failure does not leave a truncated page behind.
    """
    for style, filename in (('ent', 'entities.html'), ('dep', 'parses.html')):
        html = displacy.render(docs, style=style, page=True,
                               options={'compact': True})
        # Explicit UTF-8: the pages contain arbitrary training text that the
        # platform default encoding may not be able to represent.
        with Path('/tmp', filename).open('w', encoding='utf8') as file_:
            file_.write(html)


def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
    """Train a tagger/parser pipeline and pickle it to ``output_path``.

    Language: Language class; instantiated with a fixed pipeline.
    train_data: Gold training data handed to ``nlp.begin_training``.
    dev_data: Optional dev data; when truthy it is scored after every epoch
        via ``trainer.evaluate``, otherwise the dev columns print as zeros.
    output_path: Directory (Path) where ``model.bin`` is written with dill.
    n_iter: Number of epochs.
    **cfg: Forwarded unchanged to ``nlp.begin_training``.

    Prints one progress row per epoch and, after every batch, refreshes the
    displaCy previews of the most recent batches in /tmp.
    """
    print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")

    nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies'])

    # TODO: Get spaCy using Thinc's trainer and optimizer
    with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
        for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=True)):
            losses = defaultdict(float)
            # Newest-first previews of the last batches of this epoch.
            to_render = []
            for i, (docs, golds) in enumerate(epoch):
                state = nlp.update(docs, golds, drop=0., sgd=optimizer)
                losses['dep_loss'] += state.get('parser_loss', 0.0)
                preview = nlp(docs[-1].text)
                preview.user_data['title'] = "Batch %d" % i
                to_render.insert(0, preview)
                # Only the newest _N_PREVIEWS docs are ever rendered; dropping
                # the rest keeps the list (and the front insert) bounded.
                del to_render[_N_PREVIEWS:]
                _write_previews(to_render)
            if dev_data:
                dev_scores = trainer.evaluate(dev_data).scores
            else:
                dev_scores = defaultdict(float)
            print_progress(itn, losses, dev_scores)
    with (output_path / 'model.bin').open('wb') as file_:
        dill.dump(nlp, file_, -1)
    # TODO: replace the dill pickle with nlp.to_disk(output_path,
    # tokenizer=False) once that serialization path is ready.
|
2017-03-23 13:08:41 +03:00
|
|
|
|
|
|
|
|
2017-05-17 13:04:50 +03:00
|
|
|
def evaluate(Language, gold_tuples, path):
    """Score the pickled pipeline stored at ``path / 'model.bin'``.

    Language: Unused; the pipeline is restored from the dill pickle instead.
    gold_tuples: Iterable of ``(raw_text, sents)`` pairs; ``sents`` is
        combined with ``merge_sents`` before scoring.
    path: Directory (Path) containing ``model.bin``.

    Returns the populated ``Scorer``.
    """
    # NOTE: dill, like pickle, can run arbitrary code while loading; only
    # point this at model files you produced yourself.
    model_file = path / 'model.bin'
    with model_file.open('rb') as file_:
        nlp = dill.load(file_)

    scorer = Scorer()
    for raw_text, sents in gold_tuples:
        for annot_tuples, _brackets in merge_sents(sents):
            if raw_text is not None:
                doc = nlp(raw_text)
            else:
                # No raw text: build the Doc from the gold words and run each
                # pipeline component by hand, threading its returned state.
                doc = Doc(nlp.vocab, words=annot_tuples[1])
                state = None
                for proc in nlp.pipeline:
                    state = proc(doc, state=state)
            gold = GoldParse.from_annot_tuples(doc, annot_tuples)
            scorer.score(doc, gold)
    return scorer
|
|
|
|
|
|
|
|
|
2017-05-16 17:17:30 +03:00
|
|
|
def print_progress(itn, losses, dev_scores):
    """Print one tab-separated progress row for epoch ``itn``.

    Columns: iteration, dep_loss, uas, ents_f, tags_acc, token_acc. Any
    column missing from both ``losses`` and ``dev_scores`` prints as 0.000;
    ``dev_scores`` wins when both define the same key.
    """
    # TODO: Fix!
    merged = dict.fromkeys(['dep_loss', 'uas', 'tags_acc', 'token_acc', 'ents_f'], 0.0)
    for source in (losses, dev_scores):
        merged.update(source)
    row_template = '{:d}\t{dep_loss:.3f}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
    row = row_template.format(itn, **merged)
    print(row)
|
2017-03-23 13:08:41 +03:00
|
|
|
|
|
|
|
|
|
|
|
def print_results(scorer):
    """Print the scorer's headline metrics as a table titled "Results".

    Each metric is read from the named ``scorer`` attribute and formatted
    with two decimal places; rows keep the order listed below.
    """
    label_to_attr = [
        ('TOK', 'token_acc'),
        ('POS', 'tags_acc'),
        ('UAS', 'uas'),
        ('LAS', 'las'),
        ('NER P', 'ents_p'),
        ('NER R', 'ents_r'),
        ('NER F', 'ents_f'),
    ]
    results = {}
    for label, attr in label_to_attr:
        results[label] = '%.2f' % getattr(scorer, attr)
    util.print_table(results, title="Results")
|