clean up duplicate code

This commit is contained in:
svlandeg 2019-06-24 15:19:58 +02:00
parent ddc73b11a9
commit 58a5b40ef6

View File

@ -12,8 +12,8 @@ from thinc.api import chain
from thinc.v2v import Affine, Maxout, Softmax from thinc.v2v import Affine, Maxout, Softmax
from thinc.misc import LayerNorm from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens from .functions import merge_subtokens
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser from ..syntax.nn_parser cimport Parser
@ -1162,26 +1162,11 @@ class EntityLinker(Pipe):
return 0 return 0
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
targets = [[1] for _ in golds] # assuming we're only using positive examples # this loss function assumes we're only using positive examples
loss, gradients = self.get_cossim_loss_2(yh=scores, y=golds, t=targets) loss, gradients = get_cossim_loss(yh=scores, y=golds)
loss = loss / len(golds) loss = loss / len(golds)
return loss, gradients return loss, gradients
def get_cossim_loss_2(self, yh, y, t):
# Add a small constant to avoid 0 vectors
yh = yh + 1e-8
y = y + 1e-8
# https://math.stackexchange.com/questions/1923613/partial-derivative-of-cosine-similarity
xp = get_array_module(yh)
norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
mul_norms = norm_yh * norm_y
cos = (yh * y).sum(axis=1, keepdims=True) / mul_norms
d_yh = (y / mul_norms) - (cos * (yh / norm_yh ** 2))
loss = xp.abs(cos - t).sum()
inverse = np.asarray([int(t[i][0]) * d_yh[i] for i in range(len(t))])
return loss, -inverse
def __call__(self, doc): def __call__(self, doc):
entities, kb_ids = self.predict([doc]) entities, kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids) self.set_annotations([doc], entities, kb_ids)