Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-19 20:52:23 +03:00
Fix span embeds

Some of the lengths and backprop weren't right. Also various cleanup.

parent d7d317a1b5
commit e00bd422d9
@@ -2,7 +2,8 @@ from dataclasses import dataclass
 import warnings
 from thinc.api import Model, Linear, Relu, Dropout
-from thinc.api import chain, noop, Embed, add, tuplify
+from thinc.api import chain, noop, Embed, add, tuplify, concatenate
+from thinc.api import reduce_first, reduce_last, reduce_mean
 from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
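The newly imported reducers are not used in this hunk; presumably they feed the span reducer elsewhere in the file. As a standalone sketch (an assumption, not something shown in this diff), thinc reducers like these can be composed into a layer that turns a Ragged of token vectors into one vector per span:

# Standalone sketch (assumption, not from this commit): combine first-token,
# last-token and mean-of-tokens reductions into one span-reduction layer.
# Each reducer maps Ragged -> Floats2d; concatenate() joins the outputs
# along the feature axis.
from thinc.api import concatenate, reduce_first, reduce_last, reduce_mean

span_reducer = concatenate(reduce_first(), reduce_last(), reduce_mean())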
@@ -163,7 +164,8 @@ def span_embeddings_forward(
     # TODO support attention here
     tokvecs = xp.concatenate(tokvecs)
-    tokvecs_r = Ragged(tokvecs, docmenlens)
+    doclens = [len(doc) for doc in docs]
+    tokvecs_r = Ragged(tokvecs, doclens)
     mentions_r = Ragged(mentions, docmenlens)
     span_reduce = model.get_ref("span_reducer")
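The length fix: tokvecs is the concatenation of per-document token vectors, so the Ragged wrapping it needs per-document token counts (doclens), not per-document mention counts (docmenlens). A small standalone sketch with made-up sizes, not part of the commit:

# Standalone sketch: the Ragged lengths must sum to the number of rows in the data.
import numpy
from thinc.types import Ragged

tokvecs_per_doc = [numpy.zeros((5, 4), dtype="f"), numpy.zeros((3, 4), dtype="f")]
tokvecs = numpy.concatenate(tokvecs_per_doc)

doclens = [5, 3]      # tokens per doc: the correct lengths for tokvecs
docmenlens = [7, 2]   # mentions per doc: describes mentions, not these token rows

tokvecs_r = Ragged(tokvecs, numpy.asarray(doclens, dtype="i"))
assert tokvecs_r.lengths.sum() == tokvecs.shape[0]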
@@ -172,16 +174,15 @@ def span_embeddings_forward(
     embeds = Ragged(spanvecs, docmenlens)
 
     def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
+        grad, idxes = span_reduce_back(dY.vectors.data)
         oweights = []
         offset = 0
-        for mlen in dY.vectors.lengths:
-            hi = offset + mlen
-            vecs = dY.vectors.data[offset:hi]
-            out, out_idx = span_reduce_back(vecs)
-            oweights.append(out.data)
+        for doclen in doclens:
+            hi = offset + doclen
+            oweights.append(grad.data[offset:hi])
            offset = hi
 
         return oweights, docs
 
     return SpanEmbeddings(mentions, embeds), backprop_span_embed
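Reading the hunk, the backprop fix calls the span reducer's backward pass once over the full mention-level gradient and then splits the resulting token-level gradient per document, which is why the slices now use doclens rather than dY.vectors.lengths (the mention counts). A rough standalone sketch of that slicing pattern, with illustrative names and sizes:

# Standalone sketch: split one flat per-token gradient back into per-document pieces.
import numpy

grad_data = numpy.zeros((8, 4), dtype="f")  # token-level gradient, all docs concatenated
doclens = [5, 3]                            # tokens per document

oweights = []
offset = 0
for doclen in doclens:
    hi = offset + doclen
    oweights.append(grad_data[offset:hi])
    offset = hi

assert [g.shape[0] for g in oweights] == doclens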
@@ -420,10 +421,8 @@ def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
     def backward(d_pwsum: Floats2d) -> Floats1d:
         # For the backward pass, the gradient is distributed over the whole row and
         # column, so pull it all in.
-        dim = d_pwsum.shape[0]
-        out = ops.alloc1f(dim)
-        for ii in range(dim):
-            out[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()
+        out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
 
         return out
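The old loop and the new vectorized expression compute the same per-mention gradient (column sum plus row sum of the pairwise gradient). A quick standalone NumPy check, not part of the commit:

# Standalone check that the vectorized backward matches the old loop.
import numpy

d_pwsum = numpy.arange(9, dtype="f").reshape(3, 3)

dim = d_pwsum.shape[0]
out_loop = numpy.zeros(dim, dtype="f")
for ii in range(dim):
    out_loop[ii] = d_pwsum[:, ii].sum() + d_pwsum[ii, :].sum()

out_vec = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
assert numpy.allclose(out_loop, out_vec)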
@@ -296,7 +296,7 @@ class CoreferenceResolver(TrainablePipe):
             clusters = get_clusters_from_doc(example.reference)
             gscores = create_gold_scores(mention_idx[offset:hi], clusters)
-            gscores = xp.asarray(gscores)
+            gscores = ops.asarray2f(gscores)
             top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
             # now add the placeholder
             gold_placeholder = ~top_gscores.any(axis=1).T
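Switching from xp.asarray to ops.asarray2f routes the gold scores through the model's ops, so they come back as a 2D float32 array on whichever backend (CPU or GPU) is active. A minimal standalone illustration, not from the commit:

# Standalone illustration: Ops.asarray2f turns a nested list into a 2D float32
# array on the current backend.
from thinc.api import get_current_ops

ops = get_current_ops()
gscores = ops.asarray2f([[1.0, 0.0], [0.0, 1.0]])
assert gscores.dtype == "float32" and gscores.shape == (2, 2)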
@@ -311,9 +311,6 @@ class CoreferenceResolver(TrainablePipe):
             log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
             log_norm = ops.softmax(cscores, axis=1)
             grad = log_norm - log_marg
-            # XXX might be better to not square this
-            loss = (grad ** 2).sum()
-
             gradients.append((grad, cidx))
             total_loss += float(loss)
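The gradient line kept as context here, grad = log_norm - log_marg, is (despite the variable names) softmax(scores) minus softmax(scores + log(gold_mask)), which is the gradient of the negative marginal log-likelihood over the gold antecedents; the removed lines computed a squared-gradient value for loss reporting. A standalone numerical check of that identity for a single mention's candidate scores, not part of the commit:

# Standalone check: softmax(s) - softmax(s + log(gold)) is the gradient of
# -log(sum of gold-antecedent probabilities) with respect to the scores s.
import numpy

def softmax(x):
    e = numpy.exp(x - x.max())
    return e / e.sum()

scores = numpy.array([0.5, -1.0, 2.0, 0.3])
gold = numpy.array([1.0, 0.0, 1.0, 0.0])            # which candidates are gold
log_gold = numpy.where(gold > 0, 0.0, -numpy.inf)   # log of the 0/1 mask, no warnings

def loss(s):
    return -numpy.log((softmax(s) * gold).sum())

analytic = softmax(scores) - softmax(scores + log_gold)
eps = 1e-5
numeric = numpy.array([
    (loss(scores + eps * numpy.eye(4)[i]) - loss(scores - eps * numpy.eye(4)[i])) / (2 * eps)
    for i in range(4)
])
assert numpy.allclose(analytic, numeric, atol=1e-4)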