mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Minibatch the forward pass. THe output argmax is incorrect...
This commit is contained in:
parent
8f053fd943
commit
718e66a7b9
|
@ -118,6 +118,42 @@ class BiTagger(object):
|
|||
tags.append(self.vt.i2w[np.argmax(out.npvalue())])
|
||||
return tags
|
||||
|
||||
def predict_batch(self, words_batch):
|
||||
dynet.renew_cg()
|
||||
length = max(len(words) for words in words_batch)
|
||||
word_ids = np.zeros((length, len(words_batch)), dtype='int32')
|
||||
for j, words in enumerate(words_batch):
|
||||
for i, word in enumerate(words):
|
||||
word_ids[i, j] = self.vw.w2i.get(word, self.UNK)
|
||||
wembs = [dynet.lookup_batch(self._E, word_ids[i]) for i in range(length)]
|
||||
|
||||
f_state = self._fwd_lstm.initial_state()
|
||||
b_state = self._bwd_lstm.initial_state()
|
||||
|
||||
fw = [x.output() for x in f_state.add_inputs(wembs)]
|
||||
bw = [x.output() for x in b_state.add_inputs(reversed(wembs))]
|
||||
|
||||
H = dynet.parameter(self._pH)
|
||||
O = dynet.parameter(self._pO)
|
||||
|
||||
tags_batch = [[] for _ in range(len(words_batch))]
|
||||
for i, (f, b) in enumerate(zip(fw, reversed(bw))):
|
||||
r_t = O * (dynet.tanh(H * dynet.concatenate([f, b])))
|
||||
out = dynet.softmax(r_t).npvalue()
|
||||
for j in range(len(words_batch)):
|
||||
tags_batch[j].append(self.vt.i2w[np.argmax(out.T[j])])
|
||||
return tags_batch
|
||||
|
||||
def pipe(self, sentences):
|
||||
batch = []
|
||||
for words in sentences:
|
||||
batch.append(words)
|
||||
if len(batch) == self._minibatch_size:
|
||||
tags_batch = self.predict_batch(batch)
|
||||
for words, tags in zip(batch, tags_batch):
|
||||
yield tags
|
||||
batch = []
|
||||
|
||||
def update(self, words, tags):
|
||||
self._words_batch.append(words)
|
||||
self._tags_batch.append(tags)
|
||||
|
@ -193,10 +229,9 @@ def main(train_loc, dev_loc, model_dir):
|
|||
tagged = 0
|
||||
if i % 10000 == 0:
|
||||
good = bad = 0.0
|
||||
for sent in test:
|
||||
#word_ids = [vw.w2i.get(w, UNK) for w, t in sent]
|
||||
tags = tagger([w for w, t in sent])
|
||||
golds = [t for w, t in sent]
|
||||
word_sents = [[w for w, t in sent] for sent in test]
|
||||
gold_sents = [[t for w, t in sent] for sent in test]
|
||||
for words, tags, golds in zip(words, tagger.pipe(words), gold_sents):
|
||||
for go, gu in zip(golds, tags):
|
||||
if go == gu:
|
||||
good += 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user