spaCy/bin/parser/train_ud.py

202 lines
7.1 KiB
Python
Raw Normal View History

2017-05-06 17:47:15 +03:00
from __future__ import unicode_literals, print_function
import plac
import json
import random
2016-11-26 02:45:45 +03:00
import pathlib
from spacy.tokens import Doc
from spacy.syntax.nonproj import PseudoProjectivity
from spacy.language import Language
from spacy.gold import GoldParse
from spacy.tagger import Tagger
2017-05-06 17:47:15 +03:00
from spacy.pipeline import DependencyParser, TokenVectorEncoder
from spacy.syntax.parser import get_templates
from spacy.syntax.arc_eager import ArcEager
from spacy.scorer import Scorer
2017-01-09 18:53:46 +03:00
from spacy.language_data.tag_map import TAG_MAP as DEFAULT_TAG_MAP
import spacy.attrs
2016-11-26 02:45:45 +03:00
import io
2017-05-07 15:31:09 +03:00
from thinc.neural.ops import CupyOps
from thinc.neural import Model
2017-05-08 16:29:36 +03:00
from spacy.es import Spanish
from spacy.attrs import POS
2017-05-07 15:31:09 +03:00
2017-05-07 19:04:24 +03:00
from thinc.neural import Model
# cupy is an optional GPU dependency: fall back to ``cupy = None`` when it
# (or thinc's CupyOps) cannot be loaded. ``except Exception`` keeps the
# best-effort behaviour for any import-time failure, but unlike a bare
# ``except:`` it lets KeyboardInterrupt and SystemExit propagate.
try:
    import cupy
    from thinc.neural.ops import CupyOps
except Exception:
    cupy = None
2017-03-11 20:11:05 +03:00
def read_conllx(loc, n=0):
    """Read a 10-column, tab-separated CoNLL file, one example per sentence.

    Sentences are separated by blank lines. Lines starting with ``#`` are
    treated as comments, and rows whose id contains ``-`` or ``.`` are
    skipped. Sentences that end up with no token rows are skipped entirely
    rather than yielded as empty examples.

    For every kept token the tag is rebuilt as ``pos + '__' + dep + '__' +
    morph`` and registered in ``Spanish.Defaults.tag_map`` (a module-level
    side effect, applied regardless of which language is being trained).

    loc: Path of the file to read (decoded as UTF-8).
    n: If >= 1, stop after yielding this many sentences; 0 means no limit.

    Yields ``(None, [[annotations, []]])`` where ``annotations`` is the list
    ``[ids, words, tags, heads, deps, ner]`` of parallel per-token lists.
    Ids and heads are 0-based; a row whose head is ``'0'`` points at itself
    and a ``root`` label is rewritten to ``ROOT``. ``ner`` is always ``'O'``.

    Raises ValueError if a token row does not have exactly 10 fields or has
    a non-integer id/head.
    """
    with io.open(loc, 'r', encoding='utf8') as file_:
        text = file_.read()
    n_yielded = 0
    for sent in text.strip().split('\n\n'):
        tokens = []
        for line in sent.strip().split('\n'):
            # Blank lines (e.g. from an empty file) and comments carry no
            # token data.
            if not line.strip() or line.startswith('#'):
                continue
            fields = line.split('\t')
            if len(fields) != 10:
                raise ValueError(
                    "Expected 10 tab-separated fields, got %d: %r"
                    % (len(fields), line))
            id_, word, _lemma, pos, _tag, morph, head, dep, _1, _2 = fields
            if '-' in id_ or '.' in id_:
                continue
            id_ = int(id_) - 1
            head = (int(head) - 1) if head != '0' else id_
            dep = 'ROOT' if dep == 'root' else dep
            tag = pos + '__' + dep + '__' + morph
            Spanish.Defaults.tag_map[tag] = {POS: pos}
            tokens.append((id_, word, tag, head, dep, 'O'))
        if not tokens:
            # Nothing usable in this block; an empty example would break
            # consumers that unpack the six annotation columns.
            continue
        # Transpose the per-token tuples into six parallel column lists.
        tuples = [list(t) for t in zip(*tokens)]
        yield (None, [[tuples, []]])
        n_yielded += 1
        if n >= 1 and n_yielded >= n:
            break
2017-05-08 15:50:01 +03:00
def score_model(vocab, encoder, parser, Xs, ys, verbose=False):
    """Run the encoder and parser over ``Xs`` and score against ``ys``.

    vocab: Vocab used to build the fresh Doc objects that get parsed.
    encoder: Callable applied to each Doc before the parser.
    parser: Callable applied to each Doc after the encoder.
    Xs: Iterable of docs; only the token texts are read from them.
    ys: Iterable of gold parses, paired positionally with ``Xs``.
    verbose: Forwarded to ``Scorer.score``.

    Returns the ``Scorer`` holding the accumulated results.
    """
    scorer = Scorer()
    for doc, gold in zip(Xs, ys):
        # Rebuild the Doc from its words so the pipeline runs on a new
        # object rather than on the one passed in.
        doc = Doc(vocab, words=[w.text for w in doc])
        encoder(doc)
        parser(doc)
        PseudoProjectivity.deprojectivize(doc)
        scorer.score(doc, gold, verbose=verbose)
    return scorer
def organize_data(vocab, train_sents):
    """Turn sentence annotations into parallel lists of docs and gold parses.

    vocab: Vocab passed to every ``Doc`` that is created.
    train_sents: Iterable of ``(_, doc_sents)`` pairs, where ``doc_sents``
        is an iterable of ``((ids, words, tags, heads, deps, ner), _)``.

    Returns ``(docs, golds)``: one ``Doc`` and one ``GoldParse`` per
    sentence, in input order.
    """
    docs, golds = [], []
    annotations = (
        annots
        for _, sentences in train_sents
        for annots, _ in sentences
    )
    for _ids, words, tags, heads, deps, _ner in annotations:
        doc = Doc(vocab, words=words)
        docs.append(doc)
        golds.append(GoldParse(doc, tags=tags, heads=heads, deps=deps))
    return docs, golds
def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
    """Train a token-vector encoder and dependency parser, then score them.

    Called through ``plac.call(main)``, so the arguments come from the
    command line.

    lang_name: Language code passed to ``spacy.util.get_lang_class``.
    train_loc: Path of the training file, read with ``read_conllx``.
    dev_loc: Path of the development file, read with ``read_conllx``.
    model_dir: Output directory. It and its ``deps``, ``pos`` and ``vocab``
        sub-directories are created if missing; ``deps/config.json`` is
        (re)written on every run.
    clusters_loc: Optional path to a whitespace-separated
        ``cluster word freq`` file used to set ``lex.cluster``.

    Progress is printed once before training and after every epoch, and a
    final dev-set score line is printed at the end. The trained model is not
    written to ``model_dir``.
    """
    LangClass = spacy.util.get_lang_class(lang_name)
    train_sents = list(read_conllx(train_loc))
    dev_sents = list(read_conllx(dev_loc))
    train_sents = PseudoProjectivity.preprocess_training_data(train_sents)

    actions = ArcEager.get_actions(gold_parses=train_sents)
    features = get_templates('basic')

    model_dir = pathlib.Path(model_dir)
    for directory in (model_dir, model_dir / 'deps', model_dir / 'pos'):
        if not directory.exists():
            directory.mkdir()
    with (model_dir / 'deps' / 'config.json').open('wb') as file_:
        file_.write(
            json.dumps(
                {'pseudoprojective': True, 'labels': actions,
                 'features': features}).encode('utf8'))

    vocab = LangClass.Defaults.create_vocab()
    vocab_dir = model_dir / 'vocab'
    if not vocab_dir.exists():
        vocab_dir.mkdir()
    elif (vocab_dir / 'strings.json').exists():
        # Reuse a vocab saved by an earlier run.
        with (vocab_dir / 'strings.json').open() as file_:
            vocab.strings.load(file_)
        # NOTE(review): lexemes are only loaded when strings.json was found;
        # this nesting is reconstructed from a flattened listing -- confirm.
        if (vocab_dir / 'lexemes.bin').exists():
            vocab.load_lexemes(vocab_dir / 'lexemes.bin')

    if clusters_loc is not None:
        clusters_loc = pathlib.Path(clusters_loc)
        with clusters_loc.open() as file_:
            for line in file_:
                try:
                    cluster, word, _freq = line.split()
                except ValueError:
                    # Skip lines that are not exactly three columns.
                    continue
                lex = vocab[word]
                # The cluster column is a bit string; it is reversed before
                # being parsed as a base-2 integer.
                lex.cluster = int(cluster[::-1], 2)

    # Populate vocab: look up every word, label and tag once up front.
    for _, doc_sents in train_sents:
        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
            for word in words:
                _ = vocab[word]
            for dep in deps:
                _ = vocab[dep]
            for tag in tags:
                _ = vocab[tag]
            # NOTE(review): applied per sentence; the original nesting is
            # ambiguous in the flattened listing -- confirm.
            if vocab.morphology.tag_map:
                for tag in tags:
                    vocab.morphology.tag_map[tag] = {
                        POS: tag.split('__', 1)[0]}

    # NOTE(review): ``tagger`` is never used below; kept in case constructing
    # it has side effects on ``vocab`` -- confirm and remove if not.
    tagger = Tagger(vocab)
    encoder = TokenVectorEncoder(vocab, width=64)
    parser = DependencyParser(vocab, actions=actions, features=features,
                              L1=0.0)

    Xs, ys = organize_data(vocab, train_sents)
    dev_Xs, dev_ys = organize_data(vocab, dev_sents)
    with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
        docs = list(Xs)
        for doc in docs:
            encoder(doc)
        # One slot per progress report. Nothing in this function adds to the
        # current slot, so the loss column always prints 0.000.
        nn_loss = [0.]

        def track_progress():
            # Score the dev set using the optimizer's averaged parameters.
            with encoder.tagger.use_params(optimizer.averages):
                with parser.model.use_params(optimizer.averages):
                    scorer = score_model(vocab, encoder, parser,
                                         dev_Xs, dev_ys)
            itn = len(nn_loss)
            print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
            nn_loss.append(0.)

        track_progress()
        trainer.each_epoch.append(track_progress)
        trainer.batch_size = 24
        trainer.nb_epoch = 40
        for docs, golds in trainer.iterate(Xs, ys, progress_bar=True):
            # Train on fresh Docs built from the words of each batch item.
            docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
            tokvecs, upd_tokvecs = encoder.begin_update(docs)
            for doc, tokvec in zip(docs, tokvecs):
                doc.tensor = tokvec
            # The parser returns gradients w.r.t. the token vectors, which
            # are then passed back through the encoder's update callback.
            d_tokvecs = parser.update(docs, golds, sgd=optimizer)
            upd_tokvecs(d_tokvecs, sgd=optimizer)
            encoder.update(docs, golds, sgd=optimizer)

    # NOTE(review): ``nlp`` is only needed by the disabled save step below.
    nlp = LangClass(vocab=vocab, parser=parser)
    # Final evaluation. score_model requires both docs and gold parses, and
    # ``itn`` is local to track_progress, so reuse the prepared dev data and
    # derive the report number from nn_loss.
    scorer = score_model(vocab, encoder, parser, dev_Xs, dev_ys)
    itn = len(nn_loss)
    print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
    # TODO: saving is disabled -- the original had ``nlp.end_training(model_dir)``
    # commented out here, so nothing is persisted beyond deps/config.json.
2017-03-11 20:11:05 +03:00
if __name__ == '__main__':
    # Set to True to run training under cProfile, write the stats to
    # "Profile.prof" and print them sorted by time.
    PROFILE = False
    if PROFILE:
        import cProfile
        import pstats
        cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
        s = pstats.Stats("Profile.prof")
        s.strip_dirs().sort_stats("time").print_stats()
    else:
        # Run main exactly once; a second unconditional call would repeat
        # the whole training run.
        plac.call(main)