import plac
import json
from os import path
import shutil
import os
import random
import io

from spacy.syntax.util import Config
from spacy.gold import GoldParse
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab
from spacy.tagger import Tagger
from spacy.syntax.parser import Parser
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.parser import get_templates
from spacy.scorer import Scorer
import spacy.attrs
from spacy.syntax.nonproj import PseudoProjectivity
from spacy.syntax._parse_features import *
from spacy.language import Language

try:
    # codecs.open accepts the encoding positionally, as used in read_conllx().
    from codecs import open
except ImportError:
    pass

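# Feature templates for the parser's neural model, built from the atoms that
# spacy.syntax._parse_features star-imports above. In that module's naming
# scheme, S0/S1/S2 are the top items of the stack and N0/N1/N2 the first items
# of the buffer; l/r pick a node's leftmost/rightmost child (l2/r2 the second
# from the edge); the final letter selects the attribute: W = word form,
# p = POS tag, L = dependency label. The parallel `slots` list assigns each
# template group to an embedding table (0 = words, 1 = tags, 2 = labels).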
features = [
    (S2W,),
    (S1W,),
    (S1rW,),
    (S0lW,),
    (S0l2W,),
    (S0W,),
    (S0r2W,),
    (S0rW,),
    (N0l2W,),
    (N0lW,),
    (N0W,),
    (N1W,),
    (N2W,)
]

slots = [0] * len(features)

features += [
    (S2p,),
    (S1p,),
    (S1rp,),
    (S0lp,),
    (S0l2p,),
    (S0p,),
    (S0r2p,),
    (S0rp,),
    (N0l2p,),
    (N0lp,),
    (N0p,),
    (N1p,),
    (N2p,)
]

slots += [1] * (len(features) - len(slots))

features += [
    (S2L,),
    (S1L,),
    (S1rL,),
    (S0lL,),
    (S0l2L,),
    (S0L,),
    (S0rL,),
    (S0r2L,),
    (N0l2L,),
    (N0lL,),
]

slots += [2] * (len(features) - len(slots))

# Disabled experiments with richer conjunction templates, kept for reference:
#features += [(S2p, S1p), (S1p, S0p)]
#slots += [3, 3]
#features += [(S0p, N0p)]
#slots += [4]
# (S0l2p, S0l2L, S0lp, S0l2L),
# (N0l2p, N0l2L, N0lp, N0lL),
# (S1p, S1rp, S1rL),
# (S0p, S0rp, S0rL),
#)

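# Every template must be paired with an embedding slot; a cheap sanity check:
assert len(features) == len(slots)
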

class TreebankParser(object):
    @staticmethod
    def setup_model_dir(model_dir, labels, vector_widths=(300,), slots=(0,),
                        hidden_layers=(300, 300),
                        feat_set='basic', seed=0, update_step='sgd', eta=0.005,
                        rho=0.0):
        # Start from a clean slate: any previous model files are removed.
        dep_model_dir = path.join(model_dir, 'deps')
        pos_model_dir = path.join(model_dir, 'pos')
        if path.exists(dep_model_dir):
            shutil.rmtree(dep_model_dir)
        if path.exists(pos_model_dir):
            shutil.rmtree(pos_model_dir)
        os.mkdir(dep_model_dir)
        os.mkdir(pos_model_dir)

        Config.write(dep_model_dir, 'config', model='neural', feat_set=feat_set,
                     seed=seed, labels=labels, vector_widths=vector_widths,
                     slots=slots, hidden_layers=hidden_layers,
                     update_step=update_step, eta=eta, rho=rho)

    @classmethod
    def from_dir(cls, tag_map, model_dir):
        vocab = Vocab.load(model_dir, get_lex_attr=Language.default_lex_attrs())
        vocab.get_lex_attr[spacy.attrs.LANG] = lambda _: 0
        tokenizer = Tokenizer(vocab, {}, None, None, None)
        tagger = Tagger.blank(vocab, Tagger.default_templates())

        cfg = Config.read(path.join(model_dir, 'deps'), 'config')
        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
        return cls(vocab, tokenizer, tagger, parser)

    def __init__(self, vocab, tokenizer, tagger, parser):
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.tagger = tagger
        self.parser = parser

    def train(self, words, tags, heads, deps):
        tokens = self.tokenizer.tokens_from_list(list(words))
        ids = list(range(len(words)))  # list() so Python 3 gets indexable ids
        ner = ['O'] * len(words)
        gold = GoldParse(tokens, (ids, words, tags, heads, deps, ner))
        self.tagger.tag_from_strings(tokens, tags)
        loss = self.parser.train(tokens, gold)
        PseudoProjectivity.deprojectivize(tokens)
        return loss
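
    # Note on projectivity: main() converts the training trees to
    # pseudo-projective form via PseudoProjectivity.preprocess_training_data;
    # the deprojectivize() calls in train() and __call__() restore the
    # original non-projective arcs afterwards.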

    def __call__(self, words, tags=None):
        tokens = self.tokenizer.tokens_from_list(list(words))
        if tags is None:
            self.tagger(tokens)
        else:
            self.tagger.tag_from_strings(tokens, tags)
        self.parser(tokens)
        PseudoProjectivity.deprojectivize(tokens)
        return tokens

    def end_training(self, data_dir):
        self.parser.model.end_training()
        self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
        self.tagger.model.end_training()
        self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
        strings_loc = path.join(data_dir, 'vocab', 'strings.json')
        with io.open(strings_loc, 'w', encoding='utf8') as file_:
            self.vocab.strings.dump(file_)
        self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))


def read_conllx(loc):
    with open(loc, 'r', 'utf8') as file_:
        text = file_.read()
    for sent in text.strip().split('\n\n'):
        lines = sent.strip().split('\n')
        if lines:
            # Skip comment lines (guarding against comment-only blocks).
            while lines and lines[0].startswith('#'):
                lines.pop(0)
            tokens = []
            for line in lines:
                id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
                if '-' in id_:
                    # Multi-word token ranges such as '4-5' are skipped.
                    continue
                id_ = int(id_) - 1
                head = (int(head) - 1) if head != '0' else id_
                dep = 'ROOT' if dep == 'root' else dep
                tokens.append([id_, word, tag, head, dep, 'O'])
            tuples = [list(el) for el in zip(*tokens)]
            yield (None, [(tuples, [])])
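
# For reference, read_conllx() expects ten whitespace-separated CoNLL columns
# per token line (ID, FORM, LEMMA, CPOSTAG, POSTAG, FEATS, HEAD, DEPREL,
# PHEAD, PDEPREL); a hypothetical row:
#   1   Der   der   DET   ART   _   2   det   _   _
# HEAD is 1-based, with '0' marking the root, which read_conllx() remaps to a
# self-attachment labelled ROOT.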


def score_model(nlp, gold_docs, verbose=False):
    scorer = Scorer()
    for _, gold_doc in gold_docs:
        for annot_tuples, _ in gold_doc:
            tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
            gold = GoldParse(tokens, annot_tuples)
            scorer.score(tokens, gold, verbose=verbose)
    return scorer
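
# Scorer accumulates evaluation over the gold documents: `uas`/`las` are the
# unlabeled/labeled attachment scores and `tags_acc` is POS-tagging accuracy,
# which main() prints per iteration below.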


@plac.annotations(
    n_iter=("Number of training iterations", "option", "i", int),
)
def main(train_loc, dev_loc, model_dir, tag_map_loc, n_iter=10):
    with open(tag_map_loc) as file_:
        tag_map = json.loads(file_.read())
    train_sents = list(read_conllx(train_loc))
    train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
    dev_sents = list(read_conllx(dev_loc))

    labels = ArcEager.get_labels(train_sents)

    TreebankParser.setup_model_dir(model_dir, labels,
        feat_set=features, vector_widths=(10, 10, 10, 30, 30), slots=slots,
        hidden_layers=(100, 100, 100), update_step='adam')

    nlp = TreebankParser.from_dir(tag_map, model_dir)
    nlp.parser.model.rho = 1e-4
    print(nlp.parser.model.widths)

    for itn in range(n_iter):
        loss = 0.0
        for _, doc_sents in train_sents:
            for (ids, words, tags, heads, deps, ner), _ in doc_sents:
                loss += nlp.train(words, tags, heads, deps)
        # Reshuffle between epochs; the first pass runs in corpus order.
        random.shuffle(train_sents)
        scorer = score_model(nlp, dev_sents)
        print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
        print(nlp.parser.model.mem.size)
    nlp.end_training(model_dir)
    scorer = score_model(nlp, read_conllx(dev_loc))
    print('Dev: %.3f\t%.3f\t%.3f' % (scorer.uas, scorer.las, scorer.tags_acc))


if __name__ == '__main__':
    plac.call(main)
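
# Example invocation (paths and script name are hypothetical):
#   python train_ud.py train.conllx dev.conllx /tmp/ud_model tag_map.json -i 15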