mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix handling of unseen labels in tagger
This commit is contained in:
parent
3aabf621a3
commit
5b56aad4c2
|
@ -501,6 +501,7 @@ class Tagger(Pipe):
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||||
guesses = scores.argmax(axis=1)
|
guesses = scores.argmax(axis=1)
|
||||||
|
known_labels = numpy.ones((scores.shape[0],), dtype='f')
|
||||||
for gold in golds:
|
for gold in golds:
|
||||||
for tag in gold.tags:
|
for tag in gold.tags:
|
||||||
if tag is None:
|
if tag is None:
|
||||||
|
@ -508,10 +509,12 @@ class Tagger(Pipe):
|
||||||
elif tag in tag_index:
|
elif tag in tag_index:
|
||||||
correct[idx] = tag_index[tag]
|
correct[idx] = tag_index[tag]
|
||||||
else:
|
else:
|
||||||
correct[idx] = len(tag_index)+1
|
correct[idx] = 0
|
||||||
|
known_labels[idx] = 0.
|
||||||
idx += 1
|
idx += 1
|
||||||
correct = self.model.ops.xp.array(correct, dtype='i')
|
correct = self.model.ops.xp.array(correct, dtype='i')
|
||||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||||
|
d_scores *= known_labels
|
||||||
loss = (d_scores**2).sum()
|
loss = (d_scores**2).sum()
|
||||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||||
return float(loss), d_scores
|
return float(loss), d_scores
|
||||||
|
|
Loading…
Reference in New Issue
Block a user