spaCy/examples/spacy_dynet_lstm.py

209 lines
6.3 KiB
Python
Raw Normal View History

2016-11-13 15:45:21 +03:00
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
import pathlib
import plac
import random
from collections import Counter
import numpy as np
from collections import defaultdict
from itertools import count
#import _gdynet as dynet
#from _gdynet import cg
2016-11-13 15:45:21 +03:00
import dynet
from dynet import cg
class Vocab(object):
    """Bidirectional mapping between string items and integer ids.

    ``w2i`` maps item -> id and ``i2w`` maps id -> item. When built with
    ``from_corpus``, ids are assigned densely from 0 in first-seen order.
    """

    def __init__(self, w2i=None):
        """Build the vocabulary from an existing item -> id mapping.

        :param w2i: mapping of item to integer id, or ``None`` for an empty
            vocabulary. It is copied into a plain ``dict``, so looking up an
            unknown key later raises ``KeyError`` instead of allocating an id.
        """
        if w2i is None:
            w2i = {}
        self.w2i = dict(w2i)
        # .items() works on Python 2 and 3 (iteritems is Python 2 only).
        self.i2w = {i: w for w, i in self.w2i.items()}

    @classmethod
    def from_corpus(cls, corpus):
        """Create a vocabulary from an iterable of token sequences.

        Each distinct token gets the next free id the first time it is seen.
        """
        w2i = {}
        for sent in corpus:
            for word in sent:
                if word not in w2i:
                    w2i[word] = len(w2i)
        return cls(w2i)

    def size(self):
        """Return the number of distinct items in the vocabulary."""
        return len(self.w2i)
def read_data(path):
    """Yield sentences from a whitespace-separated, blank-line-delimited file.

    Each non-blank line is split on whitespace. Column 1 is taken as the word
    and column 3 as the tag, so every non-blank line needs at least four
    columns. Runs of blank lines separate sentences.

    :param path: object with an ``open()`` method (e.g. ``pathlib.Path``).
    :yields: one list of ``(word, tag)`` tuples per sentence.
    """
    with path.open() as file_:
        sent = []
        for line in file_:
            # str.split() with no argument already ignores surrounding whitespace.
            pieces = line.split()
            if pieces:
                sent.append((pieces[1], pieces[3]))
            elif sent:
                yield sent
                sent = []
        # A file that ends right after its last token line has no trailing
        # blank line; flush that sentence instead of silently dropping it.
        if sent:
            yield sent
def get_vocab(train, test):
    """Collect word/tag inventories and build the two vocabularies.

    Word ids cover every training word, the ``"_UNK_"`` marker and every word
    of ``test``. Tag ids cover every training tag plus ``"_START_"``.

    :param train: list of sentences, each a list of ``(word, tag)`` pairs.
    :param test: list of sentences in the same format (only words are used).
    :returns: ``(words, tags, word_counts, word_vocab, tag_vocab)``. Here
        ``words``/``tags`` are the flat token lists (with duplicates) the
        vocabularies were built from, and ``word_counts`` counts training
        words only.
    """
    word_counts = Counter()
    all_words = []
    all_tags = []
    for sentence in train:
        for word, tag in sentence:
            all_words.append(word)
            all_tags.append(tag)
            word_counts[word] += 1
    all_words.append("_UNK_")
    all_tags.append("_START_")
    all_words.extend(word for sentence in test for word, _ in sentence)
    word_vocab = Vocab.from_corpus([all_words])
    tag_vocab = Vocab.from_corpus([all_tags])
    return all_words, all_tags, word_counts, word_vocab, tag_vocab
class BiTagger(object):
    """Bidirectional-LSTM tagger built on DyNet.

    Words are embedded (128 dims) and read by a forward and a backward
    single-layer LSTM (50 hidden units each). The concatenated states are
    scored by a tanh hidden layer (32 units) and a linear layer over tags.
    """

    def __init__(self, vw, vt, nwords, ntags):
        """
        :param vw: word ``Vocab``; must contain ``"_UNK_"``.
        :param vt: tag ``Vocab``.
        :param nwords: number of rows in the word-embedding table.
        :param ntags: number of output tags.
        """
        self.vw = vw
        self.vt = vt
        self.nwords = nwords
        self.ntags = ntags
        # Word id used for words missing from ``vw``.
        self.UNK = self.vw.w2i["_UNK_"]

        self._model = dynet.Model()
        self._sgd = dynet.SimpleSGDTrainer(self._model)
        self._E = self._model.add_lookup_parameters((self.nwords, 128))
        # NOTE(review): _p_t1 is never read by this class (presumably meant for
        # previous-tag features); kept so the model's parameter set is unchanged.
        self._p_t1 = self._model.add_lookup_parameters((self.ntags, 30))
        self._pH = self._model.add_parameters((32, 50 * 2))
        self._pO = self._model.add_parameters((self.ntags, 32))
        self._fwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
        self._bwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
        self._words_batch = []
        self._tags_batch = []
        self._minibatch_size = 32

    def _encode(self, wembs):
        """Run both LSTMs over ``wembs``; return one [fwd; bwd] vector per position."""
        f_state = self._fwd_lstm.initial_state()
        b_state = self._bwd_lstm.initial_state()
        fw = [state.output() for state in f_state.add_inputs(wembs)]
        bw = [state.output() for state in b_state.add_inputs(reversed(wembs))]
        # bw was produced right-to-left; reverse it to align with fw.
        return [dynet.concatenate([f, b]) for f, b in zip(fw, reversed(bw))]

    def __call__(self, words):
        """Predict one tag string per word in ``words`` (a list of strings)."""
        dynet.renew_cg()
        word_ids = [self.vw.w2i.get(w, self.UNK) for w in words]
        wembs = [self._E[w] for w in word_ids]
        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)
        tags = []
        for h in self._encode(wembs):
            scores = O * dynet.tanh(H * h)
            # softmax is monotonic, so the argmax of the raw scores is the same.
            tags.append(self.vt.i2w[int(np.argmax(scores.npvalue()))])
        return tags

    def update(self, words, tags):
        """Queue one training sentence; train once a full minibatch is queued.

        Returns the minibatch loss when this call completes a batch of
        ``self._minibatch_size`` sentences, otherwise 0. Sentences queued at
        the end of training that never fill a batch are not trained on.
        """
        self._words_batch.append(words)
        self._tags_batch.append(tags)
        if len(self._words_batch) < self._minibatch_size:
            return 0
        loss = self.update_batch(self._words_batch, self._tags_batch)
        self._words_batch = []
        self._tags_batch = []
        return loss

    def update_batch(self, words_batch, tags_batch):
        """Take one SGD step on a minibatch and return its summed loss.

        Sentences are grouped by length, and each group runs through the
        LSTMs as its own DyNet batch. Padding mixed-length sentences into a
        single batch would score the padding positions as real tokens. It
        would also feed padding into the backward LSTM before the real words.
        Returns 0 without updating when the batch contains no tokens.
        """
        dynet.renew_cg()
        H = dynet.parameter(self._pH)
        O = dynet.parameter(self._pO)
        by_length = defaultdict(list)
        for words, tags in zip(words_batch, tags_batch):
            by_length[len(words)].append((words, tags))
        errs = []
        for length, group in by_length.items():
            if length:
                errs.extend(self._batch_errors(group, length, H, O))
        if not errs:
            return 0
        sum_errs = dynet.esum(errs)
        loss = sum_errs.scalar_value()
        sum_errs.backward()
        self._sgd.update()
        return loss

    def _batch_errors(self, group, length, H, O):
        """Return per-position losses (summed over the batch) for ``group``.

        ``group`` is a list of ``(words, tags)`` pairs whose sentences all
        have ``length`` words.
        """
        wembs = []
        for i in range(length):
            ids = [self.vw.w2i.get(words[i], self.UNK) for words, _ in group]
            # Embeddings are perturbed with noise at training time only.
            wembs.append(dynet.noise(dynet.lookup_batch(self._E, ids), 0.1))
        errs = []
        for i, h in enumerate(self._encode(wembs)):
            # Gold tags must be in the tag vocabulary. Falling back to
            # self.UNK (a *word* id) would silently pick a wrong or
            # out-of-range tag, so an unknown tag raises KeyError instead.
            gold = [self.vt.w2i[tags[i]] for _, tags in group]
            r_t = O * dynet.tanh(H * h)
            errs.append(dynet.sum_batches(dynet.pickneglogsoftmax_batch(r_t, gold)))
        return errs
2016-11-13 15:45:21 +03:00
def _evaluate(tagger, sents):
    """Return token-level accuracy of ``tagger`` on ``sents`` (0.0 if no tokens)."""
    good = bad = 0.0
    for sent in sents:
        guesses = tagger([w for w, _ in sent])
        for gold, guess in zip([t for _, t in sent], guesses):
            if gold == guess:
                good += 1
            else:
                bad += 1
    total = good + bad
    return good / total if total else 0.0


def main(train_loc, dev_loc, model_dir):
    """Train a BiLSTM tagger on ``train_loc`` and report accuracy on ``dev_loc``.

    The average training loss per token is printed every 5000 sentences.
    Dev accuracy is printed every 10000 sentences. ``model_dir`` is accepted
    for command-line compatibility but is not used.
    """
    train = list(read_data(pathlib.Path(train_loc)))
    test = list(read_data(pathlib.Path(dev_loc)))
    _, _, _, vw, vt = get_vocab(train, test)
    tagger = BiTagger(vw, vt, vw.size(), vt.size())

    tagged = loss = 0
    # range (not xrange) so the script also runs on Python 3.
    for _ in range(50):
        random.shuffle(train)
        for i, sent in enumerate(train, 1):
            if i % 5000 == 0:
                tagger._sgd.status()
                if tagged:
                    print(loss / tagged)
                loss = 0
                tagged = 0
            if i % 10000 == 0:
                print(_evaluate(tagger, test))
            loss += tagger.update([w for w, _ in sent], [t for _, t in sent])
            tagged += len(sent)
2016-11-13 15:45:21 +03:00
if __name__ == "__main__":
    # plac builds the command-line interface from main's positional
    # parameters (train_loc, dev_loc, model_dir) and then calls it.
    plac.call(main)