mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Minibatch the forward pass. THe output argmax is incorrect...
This commit is contained in:
parent
8f053fd943
commit
718e66a7b9
|
@ -118,6 +118,42 @@ class BiTagger(object):
|
||||||
tags.append(self.vt.i2w[np.argmax(out.npvalue())])
|
tags.append(self.vt.i2w[np.argmax(out.npvalue())])
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
def predict_batch(self, words_batch):
|
||||||
|
dynet.renew_cg()
|
||||||
|
length = max(len(words) for words in words_batch)
|
||||||
|
word_ids = np.zeros((length, len(words_batch)), dtype='int32')
|
||||||
|
for j, words in enumerate(words_batch):
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
|
||||||
|
wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
|
||||||
|
|
||||||
|
f_state = self._fwd_lstm.initial_state()
|
||||||
|
b_state = self._bwd_lstm.initial_state()
|
||||||
|
|
||||||
|
fw = [x.output() for x in f_state.add_inputs(wembs)]
|
||||||
|
bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
|
||||||
|
|
||||||
|
H = dynet.parameter(self._pH)
|
||||||
|
O = dynet.parameter(self._pO)
|
||||||
|
|
||||||
|
tags_batch = [[] for _ in range(len(words_batch))]
|
||||||
|
for i, (f, b) in enumerate(zip(fw, reversed(bw))):
|
||||||
|
r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
|
||||||
|
out = dynet.softmax(r_t).npvalue()
|
||||||
|
for j in range(len(words_batch)):
|
||||||
|
tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])])
|
||||||
|
return tags_batch
|
||||||
|
|
||||||
|
def pipe(self, sentences):
|
||||||
|
batch = []
|
||||||
|
for words in sentences:
|
||||||
|
batch.append(words)
|
||||||
|
if len(batch) == self._minibatch_size:
|
||||||
|
tags_batch = self.predict_batch(batch)
|
||||||
|
for words, tags in zip(batch, tags_batch):
|
||||||
|
yield tags
|
||||||
|
batch = []
|
||||||
|
|
||||||
def update(self, words, tags):
|
def update(self, words, tags):
|
||||||
self._words_batch.append(words)
|
self._words_batch.append(words)
|
||||||
self._tags_batch.append(tags)
|
self._tags_batch.append(tags)
|
||||||
|
@ -193,10 +229,9 @@ def main(train_loc, dev_loc, model_dir):
|
||||||
tagged = 0
|
tagged = 0
|
||||||
if i % 10000 == 0:
|
if i % 10000 == 0:
|
||||||
good = bad = 0.0
|
good = bad = 0.0
|
||||||
for sent in test:
|
word_sents = [[w for w, t in sent] for sent in test]
|
||||||
#word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
|
gold_sents = [[t for w, t in sent] for sent in test]
|
||||||
tags = tagger([w for w, t in sent])
|
for words, tags, golds in zip(words, tagger.pipe(words), gold_sents):
|
||||||
golds = [t for w, t in sent]
|
|
||||||
for go, gu in zip(golds, tags):
|
for go, gu in zip(golds, tags):
|
||||||
if go == gu:
|
if go == gu:
|
||||||
good += 1
|
good += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user