Catch a stray reference

This commit is contained in:
Paul O'Leary McCann 2021-05-20 21:30:46 +09:00
parent 8c5df622d8
commit ff3fed06cf

View File

@ -391,6 +391,7 @@ def ant_scorer_forward(
vecshape = vecs.data.shape
veclens = vecs.lengths
scoreshape = mscores.shape
idxes = sembeds.indices
def backprop(
dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
@ -417,7 +418,7 @@ def ant_scorer_forward(
offset += ll
# make it fit back into the linear
dXscores = xp.expand_dims(dXscores, 1)
return (dXscores, SpanEmbeddings(sembeds.indices, dXembeds))
return (dXscores, SpanEmbeddings(idxes, dXembeds))
return (out, sembeds.indices), backprop