2015-10-30 04:54:49 +03:00
|
|
|
#!/usr/bin/env python
|
2016-07-20 17:28:02 +03:00
|
|
|
from __future__ import print_function
|
2015-10-30 04:54:49 +03:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
|
|
import os
|
|
|
|
from os import path
|
|
|
|
import shutil
|
2016-02-03 00:29:34 +03:00
|
|
|
import io
|
2015-10-30 04:54:49 +03:00
|
|
|
import random
|
|
|
|
import time
|
|
|
|
import gzip
|
2016-07-20 17:28:02 +03:00
|
|
|
import re
|
|
|
|
import numpy
|
2016-09-04 17:57:10 +03:00
|
|
|
from math import sqrt
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
import plac
|
|
|
|
import cProfile
|
|
|
|
import pstats
|
|
|
|
|
|
|
|
import spacy.util
|
|
|
|
from spacy.en import English
|
|
|
|
from spacy.gold import GoldParse
|
|
|
|
|
|
|
|
from spacy.syntax.util import Config
|
2015-10-30 16:53:51 +03:00
|
|
|
from spacy.syntax.arc_eager import ArcEager
|
2016-07-20 17:28:02 +03:00
|
|
|
from spacy.syntax.parser import Parser, get_templates
|
|
|
|
from spacy.syntax.beam_parser import BeamParser
|
2015-10-30 04:54:49 +03:00
|
|
|
from spacy.scorer import Scorer
|
2015-10-30 16:53:51 +03:00
|
|
|
from spacy.tagger import Tagger
|
2016-07-20 17:28:02 +03:00
|
|
|
from spacy.syntax.nonproj import PseudoProjectivity
|
|
|
|
from spacy.syntax import _parse_features as pf
|
2015-10-30 16:53:51 +03:00
|
|
|
|
|
|
|
# Last updated for spaCy v0.97
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
|
2016-07-20 17:28:02 +03:00
|
|
|
def read_conll(file_, n=0):
    """Read a standard CoNLL/MALT-style dependency file.

    Sentences are separated by blank lines (lines containing only
    whitespace also count); every other line is one token and is parsed by
    ``_parse_line``.

    Args:
        file_: An open text file; its whole content is read at once.
        n: Maximum number of sentences to yield. ``0`` (the default) means
            no limit. (Previously one sentence too many was yielded.)

    Yields:
        ``(None, [(annot, None)])`` for each sentence, where ``annot`` is
        ``(ids, words, tags, heads, labels, ents)``:
        0-based token ids, word strings, POS tags, 0-based head indices
        (a token whose head is negative is made to point at itself),
        dependency labels, and a list of ``'O'`` (one per token) --
        presumably the entity column; confirm against GoldParse.
    """
    text = file_.read().strip()
    nr_yielded = 0
    for sent_str in re.split(r'\n\s*\n', text):
        sent_str = sent_str.strip()
        if not sent_str:
            continue
        ids = []
        words = []
        tags = []
        heads = []
        labels = []
        for i, line in enumerate(sent_str.split('\n')):
            word, pos_string, head_idx, label = _parse_line(line)
            if head_idx < 0:
                # Root token: represent it as attached to itself.
                head_idx = i
            ids.append(i)
            words.append(word)
            tags.append(pos_string)
            heads.append(head_idx)
            labels.append(label)
        annot = (ids, words, tags, heads, labels, ['O'] * len(ids))
        yield (None, [(annot, None)])
        nr_yielded += 1
        if n and nr_yielded >= n:
            break
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _parse_line(line):
|
|
|
|
pieces = line.split()
|
|
|
|
if len(pieces) == 4:
|
|
|
|
word, pos, head_idx, label = pieces
|
|
|
|
head_idx = int(head_idx)
|
2016-02-03 00:29:34 +03:00
|
|
|
elif len(pieces) == 15:
|
|
|
|
id_ = int(pieces[0].split('_')[-1])
|
|
|
|
word = pieces[1]
|
|
|
|
pos = pieces[4]
|
|
|
|
head_idx = int(pieces[8])-1
|
|
|
|
label = pieces[10]
|
2015-10-30 04:54:49 +03:00
|
|
|
else:
|
2016-02-03 00:29:34 +03:00
|
|
|
id_ = int(pieces[0].split('_')[-1])
|
2015-10-30 04:54:49 +03:00
|
|
|
word = pieces[1]
|
|
|
|
pos = pieces[4]
|
|
|
|
head_idx = int(pieces[6])-1
|
|
|
|
label = pieces[7]
|
2016-07-20 17:28:02 +03:00
|
|
|
if head_idx < 0:
|
2016-02-03 00:29:34 +03:00
|
|
|
label = 'ROOT'
|
2015-10-30 04:54:49 +03:00
|
|
|
return word, pos, head_idx, label
|
|
|
|
|
2016-07-20 17:28:02 +03:00
|
|
|
|
|
|
|
def print_words(strings, words, embeddings):
    """Print ``word vector`` for every word in ``words`` found in ``embeddings``.

    Args:
        strings: Two-way lookup. ``strings[word]`` gives a key and
            ``strings[key]`` gives the word back. Presumably a spaCy
            StringStore; confirm.
        words: Words to report, printed in this order. A word listed
            twice is printed twice.
        embeddings: Indexable object whose item 5 is an iterable of
            ``(key, values)`` pairs.
    """
    wanted_keys = set(strings[w] for w in words)
    found = {}
    for key, vector in embeddings[5]:
        if key not in wanted_keys:
            continue
        # A later pair for the same key overwrites an earlier one.
        found[strings[key]] = vector
    for w in filter(found.__contains__, words):
        print(w, found[w])
|
|
|
|
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
    """Tag, parse and score a single gold-annotated sentence.

    The sentence is rebuilt from the gold words (``annot_tuples[1]``), and
    its POS tags are set from the gold tags (``annot_tuples[2]``). Only the
    parse is predicted. ``raw_text`` is accepted but not used.

    Args:
        scorer: Accumulator whose ``score`` method is called once.
        nlp: Pipeline providing ``tokenizer``, ``tagger`` and ``parser``.
        raw_text: Unused.
        annot_tuples: Gold annotation tuple in the ``read_conll`` layout.
        verbose: Forwarded to ``scorer.score``.
    """
    doc = nlp.tokenizer.tokens_from_list(annot_tuples[1])
    nlp.tagger.tag_from_strings(doc, annot_tuples[2])
    nlp.parser(doc)
    gold_parse = GoldParse(doc, annot_tuples, make_projective=False)
    # punct_labels: labels the scorer treats as punctuation. Presumably
    # these are excluded from attachment scores; confirm in Scorer.
    scorer.score(doc, gold_parse, verbose=verbose,
                 punct_labels=('--', 'p', 'punct'))
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
|
2016-07-24 11:44:59 +03:00
|
|
|
def score_file(nlp, loc):
    """Score ``nlp`` on every sentence of the UTF-8 CoNLL file at ``loc``.

    Returns:
        A new ``Scorer`` populated with every sentence of the file.
    """
    with io.open(loc, 'r', encoding='utf8') as conll_file:
        # read_conll is lazy, so scoring must finish while the file is open.
        return score_sents(nlp, read_conll(conll_file))
|
|
|
|
|
|
|
|
|
|
|
|
def score_sents(nlp, gold_tuples):
    """Score ``nlp`` on already-loaded gold data.

    Args:
        nlp: Pipeline passed through to ``score_model``.
        gold_tuples: Iterable in the ``read_conll`` layout:
            ``(_, [(annot_tuples, _), ...])`` items.

    Returns:
        A new ``Scorer`` holding the scores of every sentence.
    """
    scorer = Scorer()
    annotations = (annot for _, sents in gold_tuples for annot, _ in sents)
    for annot_tuples in annotations:
        score_model(scorer, nlp, None, annot_tuples)
    return scorer
|
|
|
|
|
|
|
|
|
2016-07-20 17:28:02 +03:00
|
|
|
def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic',
          width=128, depth=3,
          learn_rate=0.001, noise=0.01, update_step='sgd_cm', regularization=0.0,
          batch_norm=False, seed=0, gold_preproc=False, force_gold=False):
    """Train a dependency parser on ``gold_tuples`` and save it in ``model_dir``.

    Steps:
      1. Delete and recreate ``<model_dir>/deps`` and ``<model_dir>/pos``.
      2. Write a parser config to ``<model_dir>/deps``.
      3. Build a fresh pipeline and train for up to ``n_iter`` epochs.
         Ctrl-C stops training early; the model is still saved.
      4. Call ``end_training()`` on the parser model and dump it to
         ``<model_dir>/deps/model``.

    Args:
        Language: Pipeline class. It is instantiated with
            ``data_dir=model_dir`` and the tagger, parser and entity
            components disabled.
        gold_tuples: Training data in the ``read_conll`` layout. It is
            shuffled in place every epoch by ``_train_epoch``. Its first 50
            items form a fixed micro-evaluation set.
        model_dir: Data/output directory (see above).
        dev_loc: Path of the CoNLL dev file, scored during and after training.
        n_iter: Maximum number of epochs.
        feat_set: ``'neural'`` writes a neural config that uses ``width``,
            ``depth``, ``update_step``, ``batch_norm`` and ``noise``. Any
            other value writes a config with only ``feat_set``, ``seed``,
            labels, ``eta`` and ``rho``.
        width, depth: Hidden layer size and count (neural config only).
        learn_rate: Written to the config as ``eta``.
        noise: Gradient noise (neural config only).
        update_step: Optimizer name (neural config only).
        regularization: Written to the config as ``rho``.
        batch_norm: Neural config only.
        seed: Only written to the parser config. This function does not
            seed the ``random`` module used to shuffle the training data.
        gold_preproc, force_gold: Accepted for interface compatibility;
            unused here.

    Returns:
        The trained ``Language`` instance.
    """
    dep_model_dir = path.join(model_dir, 'deps')
    pos_model_dir = path.join(model_dir, 'pos')
    # Start from empty output directories so no stale model files survive.
    for dir_ in (dep_model_dir, pos_model_dir):
        if path.exists(dir_):
            shutil.rmtree(dir_)
    os.mkdir(dep_model_dir)
    os.mkdir(pos_model_dir)

    labels = ArcEager.get_labels(gold_tuples)
    if feat_set != 'neural':
        Config.write(dep_model_dir, 'config', feat_set=feat_set, seed=seed,
                     labels=labels,
                     eta=learn_rate, rho=regularization)
    else:
        Config.write(dep_model_dir, 'config',
                     model='neural',
                     seed=seed,
                     labels=labels,
                     feat_set=feat_set,
                     hidden_layers=[width] * depth,
                     update_step=update_step,
                     batch_norm=batch_norm,
                     eta=learn_rate,
                     mu=0.9,
                     noise=noise,
                     rho=regularization)

    nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
    # Insert every training word into the vocab (a lookup adds it) before
    # the tagger and parser are created.
    for _, sents in gold_tuples:
        for annot_tuples, _ in sents:
            for word in annot_tuples[1]:
                nlp.vocab[word]
    nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
    nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
    for word in nlp.vocab:
        word.norm = word.orth
    print(nlp.parser.model.widths)

    print("Itn.\tP.Loss\tTrain\tDev\tnr_weight\tnr_feat")
    eg_seen = 0
    # Slicing a list copies it, so the in-place shuffle in _train_epoch does
    # not change this micro-evaluation sample.
    micro_eval = gold_tuples[:50]
    for itn in range(n_iter):
        try:
            eg_seen = _train_epoch(nlp, gold_tuples, eg_seen, itn,
                                   dev_loc, micro_eval)
        except KeyboardInterrupt:
            print("Saving model...")
            break
    dev_uas = score_file(nlp, dev_loc).uas
    print("Dev before average", dev_uas)

    nlp.parser.model.end_training()
    nlp.parser.model.dump(path.join(dep_model_dir, 'model'))
    print("Saved. Evaluating...")
    return nlp
|
|
|
|
|
2016-08-20 05:16:50 +03:00
|
|
|
|
2016-07-27 03:56:36 +03:00
|
|
|
def _train_epoch(nlp, gold_tuples, eg_seen, itn, dev_loc, micro_eval):
    """Run one training pass over ``gold_tuples``; return the new example count.

    ``gold_tuples`` is shuffled in place. For each sentence:

    * a doc is built from the gold words,
    * the gold POS tags are applied,
    * ``nlp.parser.train`` is called.

    Every 1000 examples a progress line is printed with: epoch, model update
    count, summed loss since the last report, micro-eval UAS, dev UAS,
    nr_weight and nr_feat. Dev UAS is only computed every 20000 examples and
    printed as 0.0 otherwise.

    Args:
        nlp: Pipeline with ``tokenizer``, ``tagger`` and ``parser``.
        gold_tuples: Training data in the ``read_conll`` layout; reordered
            in place.
        eg_seen: Number of examples seen before this epoch.
        itn: Epoch index, forwarded to ``nlp.parser.train``.
        dev_loc: Path of the CoNLL dev file, used by ``score_file``.
        micro_eval: Small gold sample scored with ``score_sents`` for the
            "Train" column.

    Returns:
        ``eg_seen`` plus the number of sentences trained on in this epoch.
    """
    random.shuffle(gold_tuples)
    loss = 0
    nr_trimmed = 0  # not used below
    for _, sents in gold_tuples:
        for annot_tuples, _ in sents:
            # Gold words and gold POS tags; only the parser is trained here.
            tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
            nlp.tagger.tag_from_strings(tokens, annot_tuples[2])
            gold = GoldParse(tokens, annot_tuples)
            loss += nlp.parser.train(tokens, gold, itn=itn)
            eg_seen += 1
            # Progress report every 1000 examples.
            if eg_seen % 1000 == 0:
                # Dev UAS is only computed every 20000 examples; 0.0 is
                # printed in its place otherwise.
                if eg_seen % 20000 == 0:
                    dev_uas = score_file(nlp, dev_loc).uas
                else:
                    dev_uas = 0.0
                train_uas = score_sents(nlp, micro_eval).uas
                nr_upd = nlp.parser.model.time
                nr_weight = nlp.parser.model.nr_weight
                nr_feat = nlp.parser.model.nr_active_feat
                print('%d,%d:\t%d\t%.3f\t%.3f\t%d\t%d' % (itn, nr_upd, int(loss),
                                                          train_uas, dev_uas,
                                                          nr_weight, nr_feat))
                # Reported loss covers only the examples since the last report.
                loss = 0
    # Multiply the learning rate by 0.99.
    # NOTE(review): confirm this decay is meant to run once per epoch rather
    # than at every 1000-example report.
    nlp.parser.model.learn_rate *= 0.99
    return eg_seen
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
|
2016-02-03 00:58:06 +03:00
|
|
|
@plac.annotations(
    train_loc=("Location of CoNLL 09 formatted training file"),
    dev_loc=("Location of CoNLL 09 formatted development file"),
    model_dir=("Location of output model directory"),
    n_iter=("Number of training iterations", "option", "i", int),
    batch_norm=("Use batch normalization and residual connections", "flag", "b"),
    update_step=("Update step", "option", "u", str),
    learn_rate=("Learn rate", "option", "e", float),
    regularization=("Regularization penalty", "option", "r", float),
    gradient_noise=("Gradient noise", "option", "W", float),
    neural=("Use neural network?", "flag", "N"),
    width=("Width of hidden layers", "option", "w", int),
    depth=("Number of hidden layers", "option", "d", int),
)
def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=False,
         width=128, depth=3, learn_rate=0.001, gradient_noise=0.0, regularization=0.0,
         update_step='sgd_cm'):
    """Command-line entry point: train a parser and print dev-set scores.

    Reads the UTF-8 CoNLL training file, preprocesses it with
    ``PseudoProjectivity``, trains via ``train`` (writing into
    ``model_dir``), then scores the dev file. Prints token, POS, UAS and
    LAS scores followed by the parser model's weight and feature counts.
    """
    with io.open(train_loc, 'r', encoding='utf8') as train_file:
        conll_sents = list(read_conll(train_file))
    # Must happen before train(), which calls ArcEager.get_labels() on this
    # data -- presumably so labels introduced by the preprocessing are
    # included in the label set.
    conll_sents = PseudoProjectivity.preprocess_training_data(conll_sents)

    if neural:
        feat_set = 'neural'
    else:
        feat_set = 'basic'
    nlp = train(English, conll_sents, model_dir, dev_loc,
                n_iter=n_iter, feat_set=feat_set,
                width=width, depth=depth, batch_norm=batch_norm,
                learn_rate=learn_rate, regularization=regularization,
                update_step=update_step, noise=gradient_noise)

    scorer = score_file(nlp, dev_loc)
    for name, attr in (('TOK', 'token_acc'), ('POS', 'tags_acc'),
                       ('UAS', 'uas'), ('LAS', 'las')):
        print(name, getattr(scorer, attr))
    print('nr_weight', nlp.parser.model.nr_weight)
    print('nr_feat', nlp.parser.model.nr_active_feat)
|
2015-10-30 04:54:49 +03:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry: plac builds the command line from main's
    # @plac.annotations and calls main with the parsed arguments.
    plac.call(main)
|