"""Coreference model components.

This includes the coref code that was being tested separately, modified to
work in spaCy. It hasn't been tested yet and presumably still needs fixes.
In particular, the evaluation code is currently omitted: it's unclear at the
moment whether we want a complex scorer similar to the official one, or a
simpler scorer using more modern evaluation methods.
"""
from dataclasses import dataclass
from typing import List, Callable, Tuple

from thinc.api import Model, Linear, Relu, Dropout, chain, noop
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged

from ...tokens import Doc
from ...util import registry
from .coref_util import (
    get_predicted_clusters,
    get_candidate_mentions,
    select_non_crossing_spans,
    make_clean_doc,
    create_gold_scores,
    logsumexp,
    topk,
)


@registry.architectures("spacy.Coref.v0")
def build_coref(
    tok2vec: Model[List[Doc], List[Floats2d]],
    get_mentions: Callable = get_candidate_mentions,
    hidden: int = 1000,
    dropout: float = 0.3,
    mention_limit: int = 3900,
    max_span_width: int = 20,
):
    # span embeddings are [start; end; mean] of the token vectors, so the
    # span width is three times the tok2vec width
    dim = tok2vec.get_dim("nO") * 3

    span_embedder = build_span_embedder(get_mentions, max_span_width)

    with Model.define_operators({">>": chain, "&": tuplify}):
        mention_scorer = (
            Linear(nI=dim, nO=hidden)
            >> Relu(nI=hidden, nO=hidden)
            >> Dropout(dropout)
            >> Linear(nI=hidden, nO=1)
        )
        mention_scorer.initialize()

        bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
        bilinear.initialize()

        ms = build_take_vecs() >> mention_scorer

        model = (
            (tok2vec & noop())
            >> span_embedder
            >> (ms & noop())
            >> build_coarse_pruner(mention_limit)
            >> build_ant_scorer(bilinear, Dropout(dropout))
        )
    return model
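
# A hedged sketch of the data flow through the assembled pipeline (the type
# comments are illustrative; "m" is the number of candidate mentions in the
# batch):
#
#     List[Doc] -> (tok2vec & noop())  -> (List[Floats2d], List[Doc])
#               -> span_embedder       -> SpanEmbeddings with m rows
#               -> (ms & noop())       -> (Floats2d[m, 1], SpanEmbeddings)
#               -> coarse_pruner       -> (Floats1d, SpanEmbeddings), pruned
#               -> ant_scorer          -> (List[(Floats2d, Ints2d)], Ints2d)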


# TODO: replace this with the thinc version once the upstream PR is merged
def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
    """Apply each layer to the same input and return the outputs as a tuple."""
    layers = (layer1, layer2) + layers
    names = [layer.name for layer in layers]
    return Model(
        "tuple(" + ", ".join(names) + ")",
        tuplify_forward,
        layers=layers,
    )


def tuplify_forward(model, X, is_train):
    Ys = []
    backprops = []
    for layer in model.layers:
        Y, backprop = layer(X, is_train)
        Ys.append(Y)
        backprops.append(backprop)

    def backprop_tuplify(dYs):
        # the layers share one input, so the input gradient is the sum of
        # the gradients contributed by each layer
        dXs = [bp(dY) for bp, dY in zip(backprops, dYs)]
        dX = dXs[0]
        for dx in dXs[1:]:
            dX += dx
        return dX

    return tuple(Ys), backprop_tuplify
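
# Hedged usage sketch for tuplify (the layer sizes below are made up for
# illustration; the sublayers are initialized directly because this
# combinator has no init of its own):
#
#     import numpy
#     from thinc.api import Linear
#     X = numpy.zeros((5, 4), dtype="f")
#     l1, l2 = Linear(nO=2, nI=4), Linear(nO=3, nI=4)
#     l1.initialize(X=X)
#     l2.initialize(X=X)
#     y1, y2 = tuplify(l1, l2).predict(X)
#     assert y1.shape == (5, 2) and y2.shape == (5, 3)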


@dataclass
class SpanEmbeddings:
    indices: Ints2d  # array with 2 columns (start and end index)
    vectors: Ragged  # Ragged[Floats2d], one vector per span
    # NB: We assume the indices refer to a concatenated Floats2d that has one
    # row per token in the *batch* of documents. This makes it unambiguous
    # which row belongs to which document: if the doc lengths are e.g.
    # [10, 5], a span starting at index 11 must start at token 1
    # (zero-indexed) of the second doc. A bug could still produce a span that
    # crosses a doc boundary, which would be bad.
    # The lengths in the Ragged are not the tokens per doc, but the number of
    # mentions per doc.

    def __add__(self, right):
        out = self.vectors.data + right.vectors.data
        return SpanEmbeddings(self.indices, Ragged(out, self.vectors.lengths))

    def __iadd__(self, right):
        self.vectors.data += right.vectors.data
        return self
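
# Worked example of the indexing convention (illustrative): with doc lengths
# [10, 5] the concatenated token matrix has 15 rows, rows 0-9 for the first
# doc and rows 10-14 for the second, so a span starting at index 11 starts
# at token 1 (zero-indexed) of the second doc.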


# model converting a Doc/Mention to span embeddings
# get_mentions: Callable[Doc, Pairs[int]]
def build_span_embedder(
    get_mentions: Callable,
    max_span_width: int = 20,
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
    return Model(
        "SpanEmbedding",
        forward=span_embeddings_forward,
        attrs={
            "get_mentions": get_mentions,
            # XXX might be better to make this an implicit parameter in the
            # mention generator
            "max_span_width": max_span_width,
        },
    )
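
# get_mentions is pluggable; a hedged sketch of the expected signature,
# assumed from the call site in span_embeddings_forward:
#
#     def get_mentions(doc: Doc, max_span_width: int) -> Tuple[List[int], List[int]]:
#         "Return parallel lists of in-doc start and end token indices."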


def span_embeddings_forward(
    model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
) -> SpanEmbeddings:
    ops = model.ops
    xp = ops.xp

    tokvecs, docs = inputs

    dim = tokvecs[0].shape[1]

    get_mentions = model.attrs["get_mentions"]
    max_span_width = model.attrs["max_span_width"]
    mentions = ops.alloc2i(0, 2)
    total_length = 0
    docmenlens = []  # number of mentions per doc
    for doc in docs:
        starts, ends = get_mentions(doc, max_span_width)
        docmenlens.append(len(starts))
        cments = ops.asarray2i([starts, ends]).transpose()

        # shift the indices so they point into the concatenated token matrix
        mentions = xp.concatenate((mentions, cments + total_length))
        total_length += len(doc)

    # TODO support attention here
    tokvecs = xp.concatenate(tokvecs)
    spans = [tokvecs[ii:jj] for ii, jj in mentions.tolist()]
    avgs = [xp.mean(ss, axis=0) for ss in spans]
    spanvecs = ops.asarray2f(avgs)

    # first and last token embeds
    starts = [tokvecs[ii] for ii in mentions[:, 0]]
    ends = [tokvecs[jj] for jj in mentions[:, 1]]

    starts = ops.asarray2f(starts)
    ends = ops.asarray2f(ends)
    # each span is represented as [start vector; end vector; mean vector]
    concat = xp.concatenate((starts, ends, spanvecs), 1)
    embeds = Ragged(concat, docmenlens)

    def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
        oweights = []
        odocs = []
        offset = 0
        tokoffset = 0
        for indoc, mlen in zip(docs, dY.vectors.lengths):
            hi = offset + mlen
            hitok = tokoffset + len(indoc)
            odocs.append(indoc)  # no change
            vecs = dY.vectors.data[offset:hi]

            starts = vecs[:, :dim]
            ends = vecs[:, dim : 2 * dim]
            spanvecs = vecs[:, 2 * dim :]

            out = model.ops.alloc2f(len(indoc), dim)

            for ii, (start, end) in enumerate(dY.indices[offset:hi]):
                # adjust indexes to align with doc
                start -= tokoffset
                end -= tokoffset

                out[start] += starts[ii]
                out[end] += ends[ii]
                # the forward pass takes a mean over the span, so its
                # gradient is distributed evenly over the span's tokens
                out[start:end] += spanvecs[ii] / (end - start)
            oweights.append(out)

            offset = hi
            tokoffset = hitok
        return oweights, odocs

    return SpanEmbeddings(mentions, embeds), backprop_span_embed
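
# Worked example of the span representation (numbers are illustrative): for
# a mention pair (3, 5), the code above concatenates
#     tokvecs[3], tokvecs[5], mean(tokvecs[3:5])
# into one vector of width 3 * dim, which is why build_coref multiplies the
# tok2vec width by 3.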


def build_coarse_pruner(
    mention_limit: int,
) -> Model[SpanEmbeddings, SpanEmbeddings]:
    model = Model(
        "CoarsePruner",
        forward=coarse_prune,
        attrs={
            "mention_limit": mention_limit,
        },
    )
    return model


def coarse_prune(
    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[Floats1d, SpanEmbeddings]:
    """Given scores for each mention, output the top non-crossing mentions.

    Mentions can contain other mentions, but candidate mentions cannot cross
    each other.
    """
    rawscores, spanembeds = inputs
    scores = rawscores.squeeze()
    mention_limit = model.attrs["mention_limit"]
    # XXX: Issue here. We don't need docs to find crossing spans, but we
    # might for the limits. In the old code the limit can be:
    # - a hard number per doc
    # - a ratio of the tokens in the doc

    offset = 0
    selected = []
    sellens = []
    for menlen in spanembeds.vectors.lengths:
        hi = offset + menlen
        cscores = scores[offset:hi]

        # negate the scores so the highest numbers come first
        tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
        starts = spanembeds.indices[offset:hi, 0].tolist()
        ends = spanembeds.indices[offset:hi, 1].tolist()

        # csel is a 1d integer list
        csel = select_non_crossing_spans(tops, starts, ends, mention_limit)
        # add the offset so these indices are absolute
        csel = [ii + offset for ii in csel]
        # this should be constant because short choices are padded
        sellens.append(len(csel))
        selected += csel
        offset += menlen

    selected = model.ops.asarray1i(selected)
    top_spans = spanembeds.indices[selected]
    top_vecs = spanembeds.vectors.data[selected]

    out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))

    def coarse_prune_backprop(
        dY: Tuple[Floats1d, SpanEmbeddings]
    ) -> Tuple[Floats1d, SpanEmbeddings]:
        ll = spanembeds.indices.shape[0]

        dYscores, dYembeds = dY

        # scatter the gradients back to their pre-pruning positions; pruned
        # spans get zero gradient
        dXscores = model.ops.alloc1f(ll)
        dXscores[selected] = dYscores.squeeze()

        dXvecs = model.ops.alloc2f(*spanembeds.vectors.data.shape)
        dXvecs[selected] = dYembeds.vectors.data
        # dXvecs covers the unpruned spans, so use the unpruned lengths here
        rout = Ragged(dXvecs, spanembeds.vectors.lengths)
        dXembeds = SpanEmbeddings(spanembeds.indices, rout)

        # inflate for the mention scorer
        dXscores = model.ops.xp.expand_dims(dXscores, 1)

        return (dXscores, dXembeds)

    return (scores[selected], out), coarse_prune_backprop
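
# Hedged illustration of the crossing constraint (behavior assumed from the
# docstring above): spans (2, 5) and (3, 4) can both be kept, since one
# contains the other, but (2, 5) and (4, 7) cross, so only the
# higher-scoring of the two survives pruning.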


def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
    # this just gets the vectors out of a SpanEmbeddings
    # XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
    return Model("TakeVecs", forward=take_vecs_forward)


def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
    def backprop(dY: Floats2d) -> SpanEmbeddings:
        vecs = Ragged(dY, inputs.vectors.lengths)
        return SpanEmbeddings(inputs.indices, vecs)

    return inputs.vectors.data, backprop


def build_ant_scorer(
    bilinear, dropout, ant_limit=50
) -> Model[Tuple[Floats1d, SpanEmbeddings], List[Floats2d]]:
    return Model(
        "AntScorer",
        forward=ant_scorer_forward,
        layers=[bilinear, dropout],
        attrs={
            "ant_limit": ant_limit,
        },
    )


def ant_scorer_forward(
    model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]:
    ops = model.ops
    xp = ops.xp

    ant_limit = model.attrs["ant_limit"]
    # this corresponds to the coarse bilinear in coref-hoi, which is a
    # single-layer linear network
    # TODO make these proper refs
    bilinear = model.layers[0]
    dropout = model.layers[1]

    # XXX Note on dimensions: This won't work as a Ragged because the
    # Floats2ds are not all the same dimensions. Each Floats2d is a square
    # whose size is the number of antecedents in the document. Actually,
    # they would have the same size if antecedents were padded... Needs
    # checking.

    mscores, sembeds = inputs
    vecs = sembeds.vectors  # ragged

    offset = 0
    backprops = []
    out = []
    for ll in vecs.lengths:
        hi = offset + ll
        # each iteration is one doc

        # first calculate the pairwise product scores
        cvecs = vecs.data[offset:hi]
        source, source_b = bilinear(cvecs, is_train)
        target, target_b = dropout(cvecs, is_train)
        pw_prod = xp.matmul(source, target.T)

        # now calculate the pairwise mention scores
        ms = mscores[offset:hi].squeeze()
        pw_sum = xp.expand_dims(ms, 1) + xp.expand_dims(ms, 0)

        # make a mask so antecedents precede referents: the log of the
        # boolean comparison is 0 for allowed pairs and -inf for the rest
        ant_range = xp.arange(0, cvecs.shape[0])
        mask = xp.log(
            (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
        ).astype(float)

        scores = pw_prod + pw_sum + mask

        top_scores, top_scores_idx = topk(xp, scores, ant_limit)
        out.append((top_scores, top_scores_idx))

        # In the full model these scores can be further refined. In the
        # current state of this model we're done here, so this pruning is
        # less important, but it's still helpful for reducing memory usage
        # (since scores can be garbage collected when the loop exits).

        offset += ll
        backprops.append((source_b, target_b, source, target))

    def backprop(
        dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
    ) -> Tuple[Floats2d, SpanEmbeddings]:
        dYscores, dYembeds = dYs
        dXembeds = Ragged(ops.alloc2f(*vecs.data.shape), vecs.lengths)
        dXscores = ops.alloc1f(*mscores.shape)

        offset = 0
        for dy, (source_b, target_b, source, target), ll in zip(
            dYscores, backprops, vecs.lengths
        ):
            # XXX the operations are not undone in reverse order here
            dyscore, dyidx = dy
            # the full score grid is square

            fullscore = ops.alloc2f(ll, ll)
            # cupy has no put_along_axis
            # xp.put_along_axis(fullscore, dyidx, dyscore, 1)
            for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
                fullscore[ii][ridx] = rscores

            # for Y = source @ target.T, the gradients are
            # dSource = dY @ target and dTarget = dY.T @ source
            dS = source_b(fullscore @ target)
            dT = target_b(fullscore.T @ source)
            dXembeds.data[offset : offset + ll] = dS + dT

            # The gradient of the pairwise sum is distributed over all the
            # rows and columns here, so aggregate it
            section = dXscores[offset : offset + ll]
            for ii in range(ll):
                section[ii] = fullscore[:, ii].sum() + fullscore[ii, :].sum()
            offset += ll
        # make it fit back into the linear
        dXscores = xp.expand_dims(dXscores, 1)
        return (dXscores, SpanEmbeddings(sembeds.indices, dXembeds))

    return (out, sembeds.indices), backprop
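
# Worked example of the antecedent mask above (n = 3, illustrative): the
# comparison (i - j) >= 1 over row index i and column index j gives
#     [[False, False, False],
#      [True,  False, False],
#      [True,  True,  False]]
# and taking the log turns True into 0 and False into -inf, so after adding
# the mask (and before topk) a mention can only take an earlier mention as
# its antecedent.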