Update dynet example to use minibatching
This commit is contained in:
parent ae681aa555
commit ef76c28d70
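The change replaces one-sentence-at-a-time training with manual minibatching: update() now buffers sentences until a batch of 32 is full, then update_batch() runs the BiLSTM over all of them at once using DyNet's batched operations (lookup_batch, pickneglogsoftmax_batch, sum_batches). A minimal standalone sketch of those three operations, using the same older DyNet Python API as the example but with toy sizes and made-up ids:

    import dynet

    model = dynet.Model()
    E = model.add_lookup_parameters((100, 16))   # toy vocab of 100, embedding dim 16
    pW = model.add_parameters((5, 16))           # toy classifier over 5 tags

    dynet.renew_cg()
    W = dynet.parameter(pW)
    # One time step for a batch of 3 sentences: element k is the word id
    # of sentence k at this step (ids are made up).
    emb = dynet.lookup_batch(E, [3, 7, 42])      # one expression, batch size 3
    scores = W * emb                             # batched matrix-vector product
    gold = [0, 2, 1]                             # gold tag ids, one per batch element
    # Per-element negative log softmax, summed over the batch dimension.
    loss = dynet.sum_batches(dynet.pickneglogsoftmax_batch(scores, gold))
    print(loss.scalar_value())                   # forward pass
    loss.backward()                              # backward pass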
@@ -11,6 +11,8 @@ import numpy as np
 from collections import defaultdict
 from itertools import count
 
+#import _gdynet as dynet
+#from _gdynet import cg
 import dynet
 from dynet import cg
 
@@ -65,15 +67,18 @@ def get_vocab(train, test):
             words.append(w)
     vw = Vocab.from_corpus([words])
     vt = Vocab.from_corpus([tags])
-    UNK = vw.w2i["_UNK_"]
     return words, tags, wc, vw, vt
 
 
 class BiTagger(object):
-    def __init__(self, nwords, ntags):
+    def __init__(self, vw, vt, nwords, ntags):
+        self.vw = vw
+        self.vt = vt
         self.nwords = nwords
         self.ntags = ntags
 
+        self.UNK = self.vw.w2i["_UNK_"]
+
         self._model = dynet.Model()
         self._sgd = dynet.SimpleSGDTrainer(self._model)
 
@@ -85,11 +90,14 @@ class BiTagger(object):
 
         self._fwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
         self._bwd_lstm = dynet.LSTMBuilder(1, 128, 50, self._model)
+        self._words_batch = []
+        self._tags_batch = []
+        self._minibatch_size = 32
 
-    def __call__(self, doc):
+    def __call__(self, words):
         dynet.renew_cg()
-        wembs = [self._E[word.rank] for word in doc]
+        word_ids = [self.vw.w2i.get(w, self.UNK) for w in words]
+        wembs = [self._E[w] for w in word_ids]
 
         f_state = self._fwd_lstm.initial_state()
         b_state = self._bwd_lstm.initial_state()
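With this hunk, __call__ takes a plain list of word strings instead of a spaCy-style doc, and (with the return added in the next hunk) yields predicted tag strings. Hypothetical usage, with a made-up sentence and output:

    tags = tagger(['I', 'like', 'dogs'])
    # -> e.g. ['PRP', 'VBP', 'NNS'], one tag string per input word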
@@ -100,14 +108,36 @@ class BiTagger(object):
         H = dynet.parameter(self._pH)
         O = dynet.parameter(self._pO)
 
+        tags = []
         for i, (f, b) in enumerate(zip(fw, reversed(bw))):
             r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
             out = dynet.softmax(r_t)
-            doc[i].tag = np.argmax(out.npvalue())
+            tags.append(self.vt.i2w[np.argmax(out.npvalue())])
+        return tags
 
-    def update(self, doc, gold):
+    def update(self, words, tags):
+        self._words_batch.append(words)
+        self._tags_batch.append(tags)
+        if len(self._words_batch) == self._minibatch_size:
+            loss = self.update_batch(self._words_batch, self._tags_batch)
+            self._words_batch = []
+            self._tags_batch = []
+        else:
+            loss = 0
+        return loss
 
+    def update_batch(self, words_batch, tags_batch):
         dynet.renew_cg()
-        wembs = [self._E[word.rank] for word in doc]
+        length = max(len(words) for words in words_batch)
+        word_ids = np.zeros((length, len(words_batch)), dtype='int32')
+        for j, words in enumerate(words_batch):
+            for i, word in enumerate(words):
+                word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
+        tag_ids = np.zeros((length, len(words_batch)), dtype='int32')
+        for j, tags in enumerate(tags_batch):
+            for i, tag in enumerate(tags):
+                tag_ids[i, j] = self.vt.w2i.get(tag, self.UNK)
+        wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
         wembs = [dynet.noise(we, 0.1) for we in wembs]
 
         f_state = self._fwd_lstm.initial_state()
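update_batch() packs the ragged batch into (length x batch_size) id matrices, zero-padding sentences shorter than the longest; since the example adds no mask, those padded positions also feed the batched loss below. A toy illustration of the packing, with a made-up stand-in for vw.w2i:

    import numpy as np

    w2i = {'a': 1, 'b': 2, 'c': 3, 'd': 4}       # stand-in for vw.w2i
    batch = [['a', 'b', 'c'], ['d']]             # ragged batch of 2 sentences
    length = max(len(s) for s in batch)          # 3
    ids = np.zeros((length, len(batch)), dtype='int32')
    for j, sent in enumerate(batch):
        for i, w in enumerate(sent):
            ids[i, j] = w2i.get(w, 0)
    print(ids)    # [[1 4]
                  #  [2 0]    row i holds every sentence's word at step i;
                  #  [3 0]]   the zeros are padding for the short sentence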
@@ -120,17 +150,17 @@ class BiTagger(object):
         O = dynet.parameter(self._pO)
 
         errs = []
-        for f, b, t in zip(fw, reversed(bw), tags):
+        for i, (f, b) in enumerate(zip(fw, reversed(bw))):
             f_b = dynet.concatenate([f,b])
             r_t = O * (dynet.tanh(H * f_b))
-            err = dynet.pickneglogsoftmax(r_t, t)
-            errs.append(err)
+            err = dynet.pickneglogsoftmax_batch(r_t, tag_ids[i])
+            errs.append(dynet.sum_batches(err))
 
         sum_errs = dynet.esum(errs)
         squared = -sum_errs # * sum_errs
-        loss += sum_errs.scalar_value()
+        losses = sum_errs.scalar_value()
         sum_errs.backward()
-        sgd.update()
+        self._sgd.update()
+        return losses
 
 
 def main(train_loc, dev_loc, model_dir):
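Two different sums are at work above: dynet.sum_batches collapses the implicit batch dimension of a single expression, while dynet.esum adds up a Python list of expressions (here, one per time step). A toy sketch of the distinction, assuming dynet.inputTensor and its batched flag, which the example itself does not use:

    import dynet
    import numpy as np

    dynet.renew_cg()
    # A 1-dim expression carrying two batch elements, 1.0 and 2.0.
    x = dynet.inputTensor(np.array([[1.0, 2.0]]), batched=True)
    s = dynet.sum_batches(x)        # sums over the batch: 3.0
    total = dynet.esum([s, s])      # sums a list of expressions: 6.0
    print(total.scalar_value())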
@@ -140,40 +170,29 @@ def main(train_loc, dev_loc, model_dir):
     train = list(read_data((train_loc)))
     test = list(read_data(dev_loc))
 
-    tagger = BiTagger(vocab)
+    words, tags, wc, vw, vt = get_vocab(train, test)
 
     UNK = vw.w2i["_UNK_"]
     nwords = vw.size()
     ntags = vt.size()
 
-    model = dynet.Model()
-    sgd = dynet.SimpleSGDTrainer(model)
-
-    E = model.add_lookup_parameters((nwords, 128))
-    p_t1 = model.add_lookup_parameters((ntags, 30))
-
-    pH = model.add_parameters((32, 50*2))
-    pO = model.add_parameters((ntags, 32))
-
-    builders=[
-        dynet.LSTMBuilder(1, 128, 50, model),
-        dynet.LSTMBuilder(1, 128, 50, model),
-    ]
+    tagger = BiTagger(vw, vt, nwords, ntags)
 
     tagged = loss = 0
 
     for ITER in xrange(50):
         random.shuffle(train)
         for i, s in enumerate(train, 1):
             if i % 5000 == 0:
-                sgd.status()
+                tagger._sgd.status()
                 print(loss / tagged)
                 loss = 0
                 tagged = 0
             if i % 10000 == 0:
                 good = bad = 0.0
                 for sent in test:
-                    word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
-                    tags = tagger.tag_sent(word_ids)
+                    #word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
+                    tags = tagger([w for w, t in sent])
                     golds = [t for w, t in sent]
                     for go, gu in zip(golds, tags):
                         if go == gu:
@@ -181,9 +200,8 @@ def main(train_loc, dev_loc, model_dir):
                         else:
                             bad += 1
                 print(good / (good+bad))
-            ws = [vw.w2i.get(w, UNK) for w,p in s]
-            ps = [vt.w2i[p] for w, p in s]
-            model.update(ws, ps)
+            loss += tagger.update([w for w, t in s], [t for w, t in s])
+            tagged += len(s)
 
 
 if __name__ == '__main__':
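Putting the pieces together, a hedged end-to-end sketch of the updated interface: read_data and get_vocab are the example's own helpers, the corpus here is made up, and get_vocab is assumed to register the "_UNK_" token as in the example.

    # Toy corpus: each sentence is a list of (word, tag) pairs.
    train = [[('the', 'DT'), ('dog', 'NN')], [('dogs', 'NNS'), ('bark', 'VBP')]]
    test = [[('the', 'DT'), ('cat', 'NN')]]

    words, tags, wc, vw, vt = get_vocab(train, test)
    tagger = BiTagger(vw, vt, vw.size(), vt.size())

    for s in train:
        ws = [w for w, t in s]
        ts = [t for w, t in s]
        # Returns 0 until a full 32-sentence minibatch has been buffered,
        # then the summed loss for that batch.
        loss = tagger.update(ws, ts)

    print(tagger(['the', 'cat']))   # list of predicted tag strings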