Fix span embeds

Some of the lengths and backprop weren't right.

Also various cleanup.
Paul O'Leary McCann 2021-07-10 20:44:20 +09:00
parent d7d317a1b5
commit e00bd422d9
3 changed files with 13 additions and 17 deletions

View File

@@ -2,7 +2,8 @@ from dataclasses import dataclass
 import warnings
 from thinc.api import Model, Linear, Relu, Dropout
-from thinc.api import chain, noop, Embed, add, tuplify
+from thinc.api import chain, noop, Embed, add, tuplify, concatenate
+from thinc.api import reduce_first, reduce_last, reduce_mean
 from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
@@ -163,7 +164,8 @@ def span_embeddings_forward(
     # TODO support attention here
     tokvecs = xp.concatenate(tokvecs)
-    tokvecs_r = Ragged(tokvecs, docmenlens)
+    doclens = [len(doc) for doc in docs]
+    tokvecs_r = Ragged(tokvecs, doclens)
     mentions_r = Ragged(mentions, docmenlens)
     span_reduce = model.get_ref("span_reducer")
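
A side note on the forward-pass fix above: Thinc's Ragged pairs a flat array with per-item lengths that are expected to partition its rows, so the token-vector Ragged has to be built with per-document token counts rather than per-document mention counts. A minimal sketch with made-up shapes (not the real model sizes):

    import numpy
    from thinc.types import Ragged

    tokvecs = numpy.zeros((7, 4), dtype="f")       # 7 tokens total across two docs
    doclens = numpy.asarray([3, 4], dtype="i")     # tokens per doc; sums to the 7 rows
    docmenlens = numpy.asarray([2, 3], dtype="i")  # mentions per doc; sums to 5, not 7

    tokvecs_r = Ragged(tokvecs, doclens)           # correct: lengths cover every token row
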
@@ -172,16 +174,15 @@ def span_embeddings_forward(
     embeds = Ragged(spanvecs, docmenlens)

     def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
+        grad, idxes = span_reduce_back(dY.vectors.data)
         oweights = []
         offset = 0
-        for mlen in dY.vectors.lengths:
-            hi = offset + mlen
-            vecs = dY.vectors.data[offset:hi]
-            out, out_idx = span_reduce_back(vecs)
-            oweights.append(out.data)
+        for doclen in doclens:
+            hi = offset + doclen
+            oweights.append(grad.data[offset:hi])
             offset = hi

         return oweights, docs

     return SpanEmbeddings(mentions, embeds), backprop_span_embed
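
A quick sketch of the corrected backprop slicing, assuming (as the change implies) that the reducer's backprop output has one row per token: the callback is now invoked once over the full gradient and the result is split back into per-document blocks by doclens, rather than being called separately on mention-aligned slices.

    import numpy

    doclens = [3, 4]                                  # toy per-document token counts
    grad = numpy.zeros((sum(doclens), 4), dtype="f")  # one gradient row per token

    oweights = []
    offset = 0
    for doclen in doclens:
        hi = offset + doclen
        oweights.append(grad[offset:hi])              # one block of token gradients per doc
        offset = hi

    assert [g.shape[0] for g in oweights] == doclens
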
@@ -420,10 +421,8 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
     def backward(d_pwsum: Floats2d) -> Floats1d:
         # For the backward pass, the gradient is distributed over the whole row and
         # column, so pull it all in.
-        dim = d_pwsum.shape[0]
-        out = ops.alloc1f(dim)
-        for ii in range(dim):
-            out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
+        out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
         return out
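
The vectorized backward above is equivalent to the removed loop: the forward pass builds a matrix of pairwise sums, pwsum[i, j] = scores[i] + scores[j], so scores[k] contributes to all of row k and column k, and its gradient is the sum of row k plus the sum of column k. A quick check with made-up values:

    import numpy

    d_pwsum = numpy.arange(9, dtype="f").reshape(3, 3)  # pretend upstream gradient

    # old loop
    dim = d_pwsum.shape[0]
    loop_out = numpy.zeros(dim, dtype="f")
    for ii in range(dim):
        loop_out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()

    # new vectorized form
    vec_out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)

    assert numpy.allclose(loop_out, vec_out)
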

View File

@@ -296,7 +296,7 @@ class CoreferenceResolver(TrainablePipe):
             clusters = get_clusters_from_doc(example.reference)
             gscores = create_gold_scores(mention_idx[offset:hi], clusters)
-            gscores = xp.asarray(gscores)
+            gscores = ops.asarray2f(gscores)
             top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
             # now add the placeholder
             gold_placeholder = ~top_gscores.any(axis=1).T
@@ -311,9 +311,6 @@ class CoreferenceResolver(TrainablePipe):
             log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
             log_norm = ops.softmax(cscores, axis=1)
             grad = log_norm - log_marg
-            # XXX might be better to not square this
-            loss = (grad ** 2).sum()
             gradients.append((grad, cidx))
             total_loss += float(loss)
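
For context on the lines removed here: grad = log_norm - log_marg matches the gradient of the negative marginal log-likelihood commonly used for coreference, i.e. softmax over all candidate scores minus softmax restricted to gold antecedents, and the deleted lines had been reporting the sum of the squared gradient as the loss. A toy illustration with made-up scores (not the pipeline's API):

    import numpy

    def softmax(x, axis=1):
        e = numpy.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    cscores = numpy.asarray([[2.0, 0.5, -1.0]])     # candidate antecedent scores
    top_gscores = numpy.asarray([[1.0, 0.0, 1.0]])  # 1.0 marks gold antecedents

    with numpy.errstate(divide="ignore"):           # log(0) -> -inf masks non-gold entries
        log_marg = softmax(cscores + numpy.log(top_gscores), axis=1)
    log_norm = softmax(cscores, axis=1)
    grad = log_norm - log_marg
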