Fix span embeds

Some of the lengths and backprop weren't right.

Also various cleanup.
Paul O'Leary McCann 2021-07-10 20:44:20 +09:00
parent d7d317a1b5
commit e00bd422d9
3 changed files with 13 additions and 17 deletions

View File

@@ -2,7 +2,8 @@ from dataclasses import dataclass
 import warnings
 from thinc.api import Model, Linear, Relu, Dropout
-from thinc.api import chain, noop, Embed, add, tuplify
+from thinc.api import chain, noop, Embed, add, tuplify, concatenate
+from thinc.api import reduce_first, reduce_last, reduce_mean
 from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
@@ -163,7 +164,8 @@ def span_embeddings_forward(
     # TODO support attention here
     tokvecs = xp.concatenate(tokvecs)
-    tokvecs_r = Ragged(tokvecs, docmenlens)
+    doclens = [len(doc) for doc in docs]
+    tokvecs_r = Ragged(tokvecs, doclens)
     mentions_r = Ragged(mentions, docmenlens)
     span_reduce = model.get_ref("span_reducer")
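Note on the fix above: in thinc, a Ragged's lengths must sum to the number of rows in its data. tokvecs stacks one row per token across all docs, so pairing it with per-doc mention counts (docmenlens) broke that invariant; the new doclens uses per-doc token counts. A minimal sketch of the invariant (the 64-dim width and the toy lengths are illustrative assumptions):

```python
# Sketch only: a Ragged's lengths must sum to its number of data rows.
import numpy
from thinc.types import Ragged

doclens = [5, 3]       # tokens per doc: 8 rows of token vectors in total
docmenlens = [4, 2]    # mentions per doc: sums to 6, not 8
tokvecs = numpy.zeros((sum(doclens), 64), dtype="f")

tokvecs_r = Ragged(tokvecs, numpy.asarray(doclens, dtype="i"))
assert tokvecs_r.lengths.sum() == tokvecs_r.data.shape[0]  # 8 == 8

# The old pairing with docmenlens claimed 6 rows for 8 rows of data,
# silently mis-slicing every per-doc view after the first doc.
```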
@@ -172,16 +174,15 @@ def span_embeddings_forward(
     embeds = Ragged(spanvecs, docmenlens)

     def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
+        grad, idxes = span_reduce_back(dY.vectors.data)
         oweights = []
         offset = 0
-        for mlen in dY.vectors.lengths:
-            hi = offset + mlen
-            vecs = dY.vectors.data[offset:hi]
-            out, out_idx = span_reduce_back(vecs)
-            oweights.append(out.data)
+        for doclen in doclens:
+            hi = offset + doclen
+            oweights.append(grad.data[offset:hi])
             offset = hi
         return oweights, docs

     return SpanEmbeddings(mentions, embeds), backprop_span_embed
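The backprop change is the second half of the same fix: the reducer ran over the full batch in the forward pass, so its backprop must also run once over the full gradient, with the result then split per document by token counts. The splitting pattern looks roughly like this (split_by_doclens is a hypothetical helper written for illustration, not part of the diff):

```python
# Hypothetical helper mirroring the new loop in backprop_span_embed.
from typing import List
import numpy
from thinc.types import Floats2d

def split_by_doclens(grad: Floats2d, doclens: List[int]) -> List[Floats2d]:
    out = []
    offset = 0
    for doclen in doclens:
        hi = offset + doclen
        out.append(grad[offset:hi])
        offset = hi
    return out

grad = numpy.arange(16, dtype="f").reshape(8, 2)
parts = split_by_doclens(grad, [5, 3])
assert [p.shape[0] for p in parts] == [5, 3]
```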
@@ -420,10 +421,8 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
     def backward(d_pwsum: Floats2d) -> Floats1d:
         # For the backward pass, the gradient is distributed over the whole row and
         # column, so pull it all in.
-        dim = d_pwsum.shape[0]
-        out = ops.alloc1f(dim)
-        for ii in range(dim):
-            out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
+        out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
         return out

View File

@@ -25,7 +25,7 @@ def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
     return chain(
         concatenate(reduce_last(), reduce_first(), reduce_mean(), reduce_max()),
         Maxout(nO=hidden_size, normalize=True, dropout=0.0),
-    )
+    )

 @registry.architectures.register("spacy.SpanCategorizer.v1")
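For context, the reducer touched here concatenates four poolings per span and projects the result with a Maxout layer. A hedged usage sketch (the 64-dim input and 128-dim hidden size are assumptions):

```python
# Sketch of applying a mean-max-style reducer to Ragged span data.
from thinc.api import Maxout, NumpyOps, chain, concatenate
from thinc.api import reduce_first, reduce_last, reduce_max, reduce_mean
from thinc.types import Ragged

model = chain(
    concatenate(reduce_last(), reduce_first(), reduce_mean(), reduce_max()),
    Maxout(nO=128, normalize=True, dropout=0.0),
)
ops = NumpyOps()
spans = Ragged(ops.alloc2f(8, 64), ops.asarray1i([5, 3]))  # two spans
model.initialize(X=spans)
Y = model.predict(spans)
assert Y.shape == (2, 128)  # one 128-dim vector per span
```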

View File

@@ -296,7 +296,7 @@ class CoreferenceResolver(TrainablePipe):
             clusters = get_clusters_from_doc(example.reference)
             gscores = create_gold_scores(mention_idx[offset:hi], clusters)
-            gscores = xp.asarray(gscores)
+            gscores = ops.asarray2f(gscores)
             top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
             # now add the placeholder
             gold_placeholder = ~top_gscores.any(axis=1).T
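The xp.asarray → ops.asarray2f change is more than cosmetic: ops.asarray2f allocates on the model's active backend (numpy or cupy) and guarantees a 2d float32 array, so the downstream take_along_axis and log see a consistent dtype. A minimal illustration with NumpyOps:

```python
from thinc.api import NumpyOps

ops = NumpyOps()
gscores = ops.asarray2f([[1, 0], [0, 1]])  # int input, float32 2d output
assert gscores.dtype == "float32"
assert gscores.ndim == 2
```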
@@ -311,9 +311,6 @@ class CoreferenceResolver(TrainablePipe):
             log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
             log_norm = ops.softmax(cscores, axis=1)
             grad = log_norm - log_marg
-            # XXX might be better to not square this
             loss = (grad ** 2).sum()
             gradients.append((grad, cidx))
             total_loss += float(loss)
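For reference, the gradient computed in this block comes from a marginal log-likelihood over gold antecedents: log(top_gscores) is -inf for non-gold entries, so log_marg is the softmax restricted to gold antecedents, and grad = log_norm - log_marg is the difference of two probability distributions. A hedged sketch with made-up scores:

```python
import numpy
from thinc.api import NumpyOps

ops = NumpyOps()
cscores = ops.asarray2f([[2.0, 1.0, 0.5]])       # candidate scores
top_gscores = ops.asarray2f([[1.0, 0.0, 1.0]])   # gold-antecedent mask

with numpy.errstate(divide="ignore"):  # log(0) -> -inf is intentional
    log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
log_norm = ops.softmax(cscores, axis=1)
grad = log_norm - log_marg

# Both terms are distributions over the same row, so the gradient sums to zero.
assert abs(float(grad.sum())) < 1e-6
```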