mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-09 22:54:53 +03:00
Migrate coref code
This includes the coref code that was being tested separately, modified to work in spaCy. It hasn't been tested yet and presumably still needs fixes. In particular, the evaluation code is currently omitted. It's unclear at the moment whether we want to use a complex scorer similar to the official one, or a simpler scorer using more modern evaluation methods.
This commit is contained in:
parent
3608b7b3f9
commit
7c42a8c90a
|
@ -1,18 +1,402 @@
|
|||
from typing import List
|
||||
from thinc.api import Model
|
||||
from thinc.types import Floats2d
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ...util import registry
|
||||
from thinc.api import Model, Linear, Relu, Dropout, chain, noop
|
||||
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
|
||||
from typing import List, Callable, Tuple
|
||||
from ...tokens import Doc
|
||||
from ...util import registry
|
||||
|
||||
from .coref_util import (
|
||||
get_predicted_clusters,
|
||||
get_candidate_mentions,
|
||||
select_non_crossing_spans,
|
||||
make_clean_doc,
|
||||
create_gold_scores,
|
||||
logsumexp,
|
||||
topk,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures("spacy.Coref.v0")
|
||||
def build_coref_model(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]]
|
||||
) -> Model:
|
||||
"""Build a coref resolution model, using a provided token-to-vector component.
|
||||
TODO.
|
||||
def build_coref(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
get_mentions: Callable = get_candidate_mentions,
|
||||
hidden: int = 1000,
|
||||
dropout: float = 0.3,
|
||||
mention_limit: int = 3900,
|
||||
max_span_width: int = 20,
|
||||
):
|
||||
dim = tok2vec.get_dim("nO") * 3
|
||||
|
||||
tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
|
||||
span_embedder = build_span_embedder(get_mentions, max_span_width)
|
||||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
|
||||
mention_scorer = (
|
||||
Linear(nI=dim, nO=hidden)
|
||||
>> Relu(nI=hidden, nO=hidden)
|
||||
>> Dropout(dropout)
|
||||
>> Linear(nI=hidden, nO=1)
|
||||
)
|
||||
mention_scorer.initialize()
|
||||
|
||||
bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
|
||||
bilinear.initialize()
|
||||
|
||||
ms = build_take_vecs() >> mention_scorer
|
||||
|
||||
model = (
|
||||
(tok2vec & noop())
|
||||
>> span_embedder
|
||||
>> (ms & noop())
|
||||
>> build_coarse_pruner(mention_limit)
|
||||
>> build_ant_scorer(bilinear, Dropout(dropout))
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# TODO replace this with thinc version once PR is in
|
||||
def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
|
||||
layers = (layer1, layer2) + layers
|
||||
names = [layer.name for layer in layers]
|
||||
return Model(
|
||||
"tuple(" + ", ".join(names) + ")",
|
||||
tuplify_forward,
|
||||
layers=layers,
|
||||
)
|
||||
|
||||
|
||||
def tuplify_forward(model, X, is_train):
|
||||
Ys = []
|
||||
backprops = []
|
||||
for layer in model.layers:
|
||||
Y, backprop = layer(X, is_train)
|
||||
Ys.append(Y)
|
||||
backprops.append(backprop)
|
||||
|
||||
def backprop_tuplify(dYs):
|
||||
dXs = [bp(dY) for bp, dY in zip(backprops, dYs)]
|
||||
dX = dXs[0]
|
||||
for dx in dXs[1:]:
|
||||
dX += dx
|
||||
return dX
|
||||
|
||||
return tuple(Ys), backprop_tuplify
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpanEmbeddings:
|
||||
indices: Ints2d # Array with 2 columns (for start and end index)
|
||||
vectors: Ragged # Ragged[Floats2d] # One vector per span
|
||||
# NB: We assume that the indices refer to a concatenated Floats2d that
|
||||
# has one row per token in the *batch* of documents. This makes it unambiguous
|
||||
# which row is in which document, because if the lengths are e.g. [10, 5],
|
||||
# a span starting at 11 must be starting at token 2 of doc 1. A bug could
|
||||
# potentially cause you to have a span which crosses a doc boundary though,
|
||||
# which would be bad.
|
||||
# The lengths in the Ragged are not the tokens per doc, but the number of
|
||||
# mentions per doc.
|
||||
|
||||
def __add__(self, right):
|
||||
out = self.vectors.data + right.vectors.data
|
||||
return SpanEmbeddings(self.indices, Ragged(out, self.vectors.lengths))
|
||||
|
||||
def __iadd__(self, right):
|
||||
self.vectors.data += right.vectors.data
|
||||
return self
|
||||
|
||||
|
||||
# model converting a Doc/Mention to span embeddings
|
||||
# get_mentions: Callable[Doc, Pairs[int]]
|
||||
def build_span_embedder(
|
||||
get_mentions: Callable,
|
||||
max_span_width: int = 20,
|
||||
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
|
||||
|
||||
return Model(
|
||||
"SpanEmbedding",
|
||||
forward=span_embeddings_forward,
|
||||
attrs={
|
||||
"get_mentions": get_mentions,
|
||||
# XXX might be better to make this an implicit parameter in the
|
||||
# mention generator
|
||||
"max_span_width": max_span_width,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def span_embeddings_forward(
|
||||
model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
|
||||
) -> SpanEmbeddings:
|
||||
ops = model.ops
|
||||
xp = ops.xp
|
||||
|
||||
tokvecs, docs = inputs
|
||||
|
||||
dim = tokvecs[0].shape[1]
|
||||
|
||||
get_mentions = model.attrs["get_mentions"]
|
||||
max_span_width = model.attrs["max_span_width"]
|
||||
mentions = ops.alloc2i(0, 2)
|
||||
total_length = 0
|
||||
docmenlens = [] # number of mentions per doc
|
||||
for doc in docs:
|
||||
starts, ends = get_mentions(doc, max_span_width)
|
||||
docmenlens.append(len(starts))
|
||||
cments = ops.asarray2i([starts, ends]).transpose()
|
||||
|
||||
mentions = xp.concatenate((mentions, cments + total_length))
|
||||
total_length += len(doc)
|
||||
|
||||
# TODO support attention here
|
||||
tokvecs = xp.concatenate(tokvecs)
|
||||
spans = [tokvecs[ii:jj] for ii, jj in mentions.tolist()]
|
||||
avgs = [xp.mean(ss, axis=0) for ss in spans]
|
||||
spanvecs = ops.asarray2f(avgs)
|
||||
|
||||
# first and last token embeds
|
||||
starts = [tokvecs[ii] for ii in mentions[:, 0]]
|
||||
ends = [tokvecs[jj] for jj in mentions[:, 1]]
|
||||
|
||||
starts = ops.asarray2f(starts)
|
||||
ends = ops.asarray2f(ends)
|
||||
concat = xp.concatenate((starts, ends, spanvecs), 1)
|
||||
embeds = Ragged(concat, docmenlens)
|
||||
|
||||
def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
|
||||
|
||||
oweights = []
|
||||
odocs = []
|
||||
offset = 0
|
||||
tokoffset = 0
|
||||
for indoc, mlen in zip(docs, dY.vectors.lengths):
|
||||
hi = offset + mlen
|
||||
hitok = tokoffset + len(indoc)
|
||||
odocs.append(indoc) # no change
|
||||
vecs = dY.vectors.data[offset:hi]
|
||||
|
||||
starts = vecs[:, :dim]
|
||||
ends = vecs[:, dim : 2 * dim]
|
||||
spanvecs = vecs[:, 2 * dim :]
|
||||
|
||||
out = model.ops.alloc2f(len(indoc), dim)
|
||||
|
||||
for ii, (start, end) in enumerate(dY.indices[offset:hi]):
|
||||
# adjust indexes to align with doc
|
||||
start -= tokoffset
|
||||
end -= tokoffset
|
||||
|
||||
out[start] += starts[ii]
|
||||
out[end] += ends[ii]
|
||||
out[start:end] += spanvecs[ii]
|
||||
oweights.append(out)
|
||||
|
||||
offset = hi
|
||||
tokoffset = hitok
|
||||
return oweights, odocs
|
||||
|
||||
return SpanEmbeddings(mentions, embeds), backprop_span_embed
|
||||
|
||||
|
||||
def build_coarse_pruner(
|
||||
mention_limit: int,
|
||||
) -> Model[SpanEmbeddings, SpanEmbeddings]:
|
||||
model = Model(
|
||||
"CoarsePruner",
|
||||
forward=coarse_prune,
|
||||
attrs={
|
||||
"mention_limit": mention_limit,
|
||||
},
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def coarse_prune(
|
||||
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
|
||||
) -> SpanEmbeddings:
|
||||
"""Given scores for mention, output the top non-crossing mentions.
|
||||
|
||||
Mentions can contain other mentions, but candidate mentions cannot cross each other.
|
||||
"""
|
||||
return tok2vec
|
||||
rawscores, spanembeds = inputs
|
||||
scores = rawscores.squeeze()
|
||||
mention_limit = model.attrs["mention_limit"]
|
||||
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
|
||||
# In old code the limit can be:
|
||||
# - hard number per doc
|
||||
# - ratio of tokens in the doc
|
||||
|
||||
offset = 0
|
||||
selected = []
|
||||
sellens = []
|
||||
for menlen in spanembeds.vectors.lengths:
|
||||
hi = offset + menlen
|
||||
cscores = scores[offset:hi]
|
||||
|
||||
# negate it so highest numbers come first
|
||||
tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
|
||||
starts = spanembeds.indices[offset:hi, 0].tolist()
|
||||
ends = spanembeds.indices[offset:hi:, 1].tolist()
|
||||
|
||||
# csel is a 1d integer list
|
||||
csel = select_non_crossing_spans(tops, starts, ends, mention_limit)
|
||||
# add the offset so these indices are absolute
|
||||
csel = [ii + offset for ii in csel]
|
||||
# this should be constant because short choices are padded
|
||||
sellens.append(len(csel))
|
||||
selected += csel
|
||||
offset += menlen
|
||||
|
||||
selected = model.ops.asarray1i(selected)
|
||||
top_spans = spanembeds.indices[selected]
|
||||
top_vecs = spanembeds.vectors.data[selected]
|
||||
|
||||
out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
|
||||
|
||||
def coarse_prune_backprop(
|
||||
dY: Tuple[Floats1d, SpanEmbeddings]
|
||||
) -> Tuple[Floats1d, SpanEmbeddings]:
|
||||
ll = spanembeds.indices.shape[0]
|
||||
|
||||
dYscores, dYembeds = dY
|
||||
|
||||
dXscores = model.ops.alloc1f(ll)
|
||||
dXscores[selected] = dYscores.squeeze()
|
||||
|
||||
dXvecs = model.ops.alloc2f(*spanembeds.vectors.data.shape)
|
||||
dXvecs[selected] = dYembeds.vectors.data
|
||||
rout = Ragged(dXvecs, out.vectors.lengths)
|
||||
dXembeds = SpanEmbeddings(spanembeds.indices, rout)
|
||||
|
||||
# inflate for mention scorer
|
||||
dXscores = model.ops.xp.expand_dims(dXscores, 1)
|
||||
|
||||
return (dXscores, dXembeds)
|
||||
|
||||
return (scores[selected], out), coarse_prune_backprop
|
||||
|
||||
|
||||
def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
|
||||
# this just gets vectors out of spanembeddings
|
||||
# XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
|
||||
return Model("TakeVecs", forward=take_vecs_forward)
|
||||
|
||||
|
||||
def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
|
||||
def backprop(dY: Floats2d) -> SpanEmbeddings:
|
||||
vecs = Ragged(dY, inputs.vectors.lengths)
|
||||
return SpanEmbeddings(inputs.indices, vecs)
|
||||
|
||||
return inputs.vectors.data, backprop
|
||||
|
||||
|
||||
def build_ant_scorer(
|
||||
bilinear, dropout, ant_limit=50
|
||||
) -> Model[Tuple[Floats1d, SpanEmbeddings], List[Floats2d]]:
|
||||
return Model(
|
||||
"AntScorer",
|
||||
forward=ant_scorer_forward,
|
||||
layers=[bilinear, dropout],
|
||||
attrs={
|
||||
"ant_limit": ant_limit,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def ant_scorer_forward(
|
||||
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
|
||||
) -> Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]:
|
||||
ops = model.ops
|
||||
xp = ops.xp
|
||||
|
||||
ant_limit = model.attrs["ant_limit"]
|
||||
# this contains the coarse bilinear in coref-hoi
|
||||
# coarse bilinear is a single layer linear network
|
||||
# TODO make these proper refs
|
||||
bilinear = model.layers[0]
|
||||
dropout = model.layers[1]
|
||||
|
||||
# XXX Note on dimensions: This won't work as a ragged because the floats2ds
|
||||
# are not all the same dimentions. Each floats2d is a square in the size of
|
||||
# the number of antecedents in the document. Actually, that will have the
|
||||
# same size if antecedents are padded... Needs checking.
|
||||
|
||||
mscores, sembeds = inputs
|
||||
vecs = sembeds.vectors # ragged
|
||||
|
||||
offset = 0
|
||||
backprops = []
|
||||
out = []
|
||||
for ll in vecs.lengths:
|
||||
hi = offset + ll
|
||||
# each iteration is one doc
|
||||
|
||||
# first calculate the pairwise product scores
|
||||
cvecs = vecs.data[offset:hi]
|
||||
source, source_b = bilinear(cvecs, is_train)
|
||||
target, target_b = dropout(cvecs, is_train)
|
||||
pw_prod = xp.matmul(source, target.T)
|
||||
|
||||
# now calculate the pairwise mention scores
|
||||
ms = mscores[offset:hi].squeeze()
|
||||
pw_sum = xp.expand_dims(ms, 1) + xp.expand_dims(ms, 0)
|
||||
|
||||
# make a mask so antecedents precede referrents
|
||||
ant_range = xp.arange(0, cvecs.shape[0])
|
||||
# with xp.errstate(divide="ignore"):
|
||||
# mask = xp.log(
|
||||
# (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
# ).astype(float)
|
||||
mask = xp.log(
|
||||
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
).astype(float)
|
||||
|
||||
scores = pw_prod + pw_sum + mask
|
||||
|
||||
top_scores, top_scores_idx = topk(xp, scores, ant_limit)
|
||||
out.append((top_scores, top_scores_idx))
|
||||
|
||||
# In the full model these scores can be further refined. In the current
|
||||
# state of this model we're done here, so this pruning is less important,
|
||||
# but it's still helpful for reducing memory usage (since scores can be
|
||||
# garbage collected when the loop exits).
|
||||
|
||||
offset += ll
|
||||
backprops.append((source_b, target_b, source, target))
|
||||
|
||||
def backprop(
|
||||
dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
|
||||
) -> Tuple[Floats2d, SpanEmbeddings]:
|
||||
dYscores, dYembeds = dYs
|
||||
dXembeds = Ragged(ops.alloc2f(*vecs.data.shape), vecs.lengths)
|
||||
dXscores = ops.alloc1f(*mscores.shape)
|
||||
|
||||
offset = 0
|
||||
for dy, (source_b, target_b, source, target), ll in zip(
|
||||
dYscores, backprops, vecs.lengths
|
||||
):
|
||||
# I'm not undoing the operations in the right order here.
|
||||
dyscore, dyidx = dy
|
||||
# the full score grid is square
|
||||
|
||||
fullscore = ops.alloc2f(ll, ll)
|
||||
# cupy has no put_along_axis
|
||||
# xp.put_along_axis(fullscore, dyidx, dyscore, 1)
|
||||
for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
|
||||
fullscore[ii][ridx] = rscores
|
||||
|
||||
dS = source_b(fullscore @ target)
|
||||
dT = target_b(fullscore @ source)
|
||||
dXembeds.data[offset : offset + ll] = dS + dT
|
||||
|
||||
# The gradient can be distributed over all the rows and columns here,
|
||||
# so aggregate it
|
||||
section = dXscores[offset : offset + ll]
|
||||
for ii in range(ll):
|
||||
section[ii] = fullscore[:, ii].sum() + fullscore[ii, :].sum()
|
||||
offset += ll
|
||||
# make it fit back into the linear
|
||||
dXscores = xp.expand_dims(dXscores, 1)
|
||||
return (dXscores, SpanEmbeddings(sembeds.indices, dXembeds))
|
||||
|
||||
return (out, sembeds.indices), backprop
|
||||
|
|
252
spacy/ml/models/coref_util.py
Normal file
252
spacy/ml/models/coref_util.py
Normal file
|
@ -0,0 +1,252 @@
|
|||
from thinc.types import Ints2d
|
||||
from spacy.tokens import Doc
|
||||
from typing import List, Tuple
|
||||
|
||||
# type alias to make writing this less tedious
|
||||
MentionClusters = List[List[Tuple[int, int]]]
|
||||
|
||||
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
|
||||
|
||||
|
||||
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
|
||||
"""Given a doc, give the mention clusters.
|
||||
|
||||
This is useful for scoring.
|
||||
"""
|
||||
out = []
|
||||
for name, val in doc.spans.items():
|
||||
if not name.startswith(prefix):
|
||||
continue
|
||||
|
||||
cluster = []
|
||||
for mention in val:
|
||||
cluster.append((mention.start, mention.end))
|
||||
out.append(cluster)
|
||||
return out
|
||||
|
||||
|
||||
def topk(xp, arr, k, axis=None):
|
||||
"""Given and array and a k value, give the top values and idxs for each row."""
|
||||
|
||||
part = xp.argpartition(arr, -k, axis=1)
|
||||
idxs = xp.flip(part)[:, :k]
|
||||
|
||||
vals = xp.take_along_axis(arr, idxs, axis=1)
|
||||
|
||||
sidxs = xp.argsort(vals, axis=1)
|
||||
# map these idxs back to the original
|
||||
oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
|
||||
svals = xp.take_along_axis(vals, sidxs, axis=1)
|
||||
return svals, oidxs
|
||||
|
||||
|
||||
def logsumexp(xp, arr, axis=None):
|
||||
"""Emulate torch.logsumexp by returning the log of summed exponentials
|
||||
along each row in the given dimension.
|
||||
|
||||
Reduces a 2d array to 1d."""
|
||||
# from slide 5 here:
|
||||
# https://www.slideshare.net/ryokuta/cupy
|
||||
hi = arr.max(axis=axis)
|
||||
hi = xp.expand_dims(hi, 1)
|
||||
return hi.squeeze() + xp.log(xp.exp(arr - hi).sum(axis=axis))
|
||||
|
||||
|
||||
# from model.py, refactored to be non-member
|
||||
def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
|
||||
"""Get the ID of the antecedent for each span. -1 if no antecedent."""
|
||||
predicted_antecedents = []
|
||||
for i, idx in enumerate(xp.argmax(antecedent_scores, axis=1) - 1):
|
||||
if idx < 0:
|
||||
predicted_antecedents.append(-1)
|
||||
else:
|
||||
predicted_antecedents.append(antecedent_idx[i][idx])
|
||||
return predicted_antecedents
|
||||
|
||||
|
||||
# from model.py, refactored to be non-member
|
||||
def get_predicted_clusters(
|
||||
xp, span_starts, span_ends, antecedent_idx, antecedent_scores
|
||||
):
|
||||
"""Convert predictions to usable cluster data.
|
||||
|
||||
return values:
|
||||
|
||||
clusters: a list of spans (i, j) that are a cluster
|
||||
|
||||
Note that not all spans will be in the final output; spans with no
|
||||
antecedent or referrent are omitted from clusters and mention2cluster.
|
||||
"""
|
||||
# Get predicted antecedents
|
||||
predicted_antecedents = get_predicted_antecedents(
|
||||
xp, antecedent_idx, antecedent_scores
|
||||
)
|
||||
|
||||
# Get predicted clusters
|
||||
mention_to_cluster_id = {}
|
||||
predicted_clusters = []
|
||||
for i, predicted_idx in enumerate(predicted_antecedents):
|
||||
if predicted_idx < 0:
|
||||
continue
|
||||
assert i > predicted_idx, f"span idx: {i}; antecedent idx: {predicted_idx}"
|
||||
# Check antecedent's cluster
|
||||
antecedent = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx]))
|
||||
antecedent_cluster_id = mention_to_cluster_id.get(antecedent, -1)
|
||||
if antecedent_cluster_id == -1:
|
||||
antecedent_cluster_id = len(predicted_clusters)
|
||||
predicted_clusters.append([antecedent])
|
||||
mention_to_cluster_id[antecedent] = antecedent_cluster_id
|
||||
# Add mention to cluster
|
||||
mention = (int(span_starts[i]), int(span_ends[i]))
|
||||
predicted_clusters[antecedent_cluster_id].append(mention)
|
||||
mention_to_cluster_id[mention] = antecedent_cluster_id
|
||||
|
||||
predicted_clusters = [tuple(c) for c in predicted_clusters]
|
||||
return predicted_clusters
|
||||
|
||||
|
||||
def get_sentence_map(doc: Doc):
|
||||
"""For the given span, return a list of sentence indexes."""
|
||||
|
||||
si = 0
|
||||
out = []
|
||||
for sent in doc.sents:
|
||||
for tok in sent:
|
||||
out.append(si)
|
||||
si += 1
|
||||
return out
|
||||
|
||||
|
||||
def get_candidate_mentions(
|
||||
doc: Doc, max_span_width: int = 20
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
"""Given a Doc, return candidate mentions.
|
||||
|
||||
This isn't a trainable layer, it just returns raw candidates.
|
||||
"""
|
||||
# XXX Note that in coref-hoi the indexes are designed so you actually want [i:j+1], but here
|
||||
# we're using [i:j], which is more natural.
|
||||
|
||||
sentence_map = get_sentence_map(doc)
|
||||
|
||||
begins = []
|
||||
ends = []
|
||||
for tok in doc:
|
||||
si = sentence_map[tok.i] # sentence index
|
||||
for ii in range(1, max_span_width):
|
||||
ei = tok.i + ii # end index
|
||||
if ei < len(doc) and sentence_map[ei] == si:
|
||||
begins.append(tok.i)
|
||||
ends.append(ei)
|
||||
|
||||
return (begins, ends)
|
||||
|
||||
|
||||
def select_non_crossing_spans(
|
||||
idxs: List[int], starts: List[int], ends: List[int], limit: int
|
||||
) -> List[int]:
|
||||
"""Given a list of spans sorted in descending order, return the indexes of
|
||||
spans to keep, discarding spans that cross.
|
||||
|
||||
Nested spans are allowed.
|
||||
"""
|
||||
# ported from Model._extract_top_spans
|
||||
selected = []
|
||||
start_to_max_end = {}
|
||||
end_to_min_start = {}
|
||||
|
||||
for idx in idxs:
|
||||
if len(selected) >= limit or idx > len(starts):
|
||||
break
|
||||
|
||||
start, end = starts[idx], ends[idx]
|
||||
cross = False
|
||||
|
||||
for ti in range(start, end + 1):
|
||||
max_end = start_to_max_end.get(ti, -1)
|
||||
if ti > start and max_end > end:
|
||||
cross = True
|
||||
break
|
||||
|
||||
min_start = end_to_min_start.get(ti, -1)
|
||||
if ti < end and 0 <= min_start < start:
|
||||
cross = True
|
||||
break
|
||||
|
||||
if not cross:
|
||||
# this index will be kept
|
||||
# record it so we can exclude anything that crosses it
|
||||
selected.append(idx)
|
||||
max_end = start_to_max_end.get(start, -1)
|
||||
if end > max_end:
|
||||
start_to_max_end[start] = end
|
||||
min_start = end_to_min_start.get(end, -1)
|
||||
if start == -1 or start < min_start:
|
||||
end_to_min_start[end] = start
|
||||
|
||||
# sort idxs by order in doc
|
||||
selected = sorted(selected, key=lambda idx: (starts[idx], ends[idx]))
|
||||
while len(selected) < limit:
|
||||
selected.append(selected[0]) # this seems a bit weird?
|
||||
return selected
|
||||
|
||||
|
||||
def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
|
||||
"""Given a Doc, convert the cluster spans to simple int tuple lists."""
|
||||
out = []
|
||||
for key, val in doc.spans.items():
|
||||
cluster = []
|
||||
for span in val:
|
||||
# TODO check that there isn't an off-by-one error here
|
||||
cluster.append((span.start, span.end))
|
||||
out.append(cluster)
|
||||
return out
|
||||
|
||||
|
||||
def make_clean_doc(nlp, doc):
|
||||
"""Return a doc with raw data but not span annotations."""
|
||||
# Surely there is a better way to do this?
|
||||
|
||||
sents = [tok.is_sent_start for tok in doc]
|
||||
words = [tok.text for tok in doc]
|
||||
out = Doc(nlp.vocab, words=words, sent_starts=sents)
|
||||
return out
|
||||
|
||||
|
||||
def create_gold_scores(
|
||||
ments: Ints2d, clusters: List[List[Tuple[int, int]]]
|
||||
) -> List[List[bool]]:
|
||||
"""Given mentions considered for antecedents and gold clusters,
|
||||
construct a gold score matrix. This does not include the placeholder."""
|
||||
# make a mapping of mentions to cluster id
|
||||
# id is not important but equality will be
|
||||
ment2cid = {}
|
||||
for cid, cluster in enumerate(clusters):
|
||||
for ment in cluster:
|
||||
ment2cid[ment] = cid
|
||||
|
||||
ll = len(ments)
|
||||
out = []
|
||||
# The .tolist() call is necessary with cupy but not numpy
|
||||
mentuples = [tuple(mm.tolist()) for mm in ments]
|
||||
for ii, ment in enumerate(mentuples):
|
||||
if ment not in ment2cid:
|
||||
# this is not in a cluster so it has no antecedent
|
||||
out.append([False] * ll)
|
||||
continue
|
||||
|
||||
# this might change if no real antecedent is a candidate
|
||||
row = []
|
||||
cid = ment2cid[ment]
|
||||
for jj, ante in enumerate(mentuples):
|
||||
# antecedents must come first
|
||||
if jj >= ii:
|
||||
row.append(False)
|
||||
continue
|
||||
|
||||
row.append(cid == ment2cid.get(ante, -1))
|
||||
|
||||
out.append(row)
|
||||
|
||||
# caller needs to convert to array, and add placeholder
|
||||
return out
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Iterable, Tuple, Optional, Dict, Callable, Any
|
||||
from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
||||
|
||||
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
||||
from thinc.types import Floats2d, Ints2d
|
||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||
from itertools import islice
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
|
@ -12,10 +13,25 @@ from ..scorer import Scorer
|
|||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
|
||||
from ..ml.models.coref_util import (
|
||||
create_gold_scores,
|
||||
MentionClusters,
|
||||
get_clusters_from_doc,
|
||||
logsumexp,
|
||||
get_predicted_clusters,
|
||||
DEFAULT_CLUSTER_PREFIX,
|
||||
doc2clusters,
|
||||
)
|
||||
|
||||
|
||||
default_config = """
|
||||
[model]
|
||||
@architectures = "spacy.Coref.v0"
|
||||
max_span_width = 20
|
||||
mention_limit = 3900
|
||||
dropout = 0.3
|
||||
hidden = 1000
|
||||
@get_mentions = "spacy.CorefCandidateGenerator.v0"
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
|
@ -41,12 +57,11 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
|||
|
||||
@Language.factory(
|
||||
"coref",
|
||||
assigns=[f"doc.spans"],
|
||||
assigns=["doc.spans"],
|
||||
requires=["doc.spans"],
|
||||
default_config={
|
||||
"model": DEFAULT_MODEL,
|
||||
"span_mentions": DEFAULT_MENTIONS,
|
||||
"span_cluster_prefix": DEFAULT_CLUSTERS_PREFIX,
|
||||
"span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
|
||||
},
|
||||
default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
|
||||
)
|
||||
|
@ -54,21 +69,11 @@ def make_coref(
|
|||
nlp: Language,
|
||||
name: str,
|
||||
model,
|
||||
span_mentions: str,
|
||||
span_cluster_prefix: str,
|
||||
span_cluster_prefix: str = "coref",
|
||||
) -> "CoreferenceResolver":
|
||||
"""Create a CoreferenceResolver component. TODO
|
||||
"""Create a CoreferenceResolver component."""
|
||||
|
||||
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts ...
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
"""
|
||||
return CoreferenceResolver(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
span_mentions=span_mentions,
|
||||
span_cluster_prefix=span_cluster_prefix,
|
||||
)
|
||||
return CoreferenceResolver(nlp.vocab, model, name, span_cluster_prefix)
|
||||
|
||||
|
||||
class CoreferenceResolver(TrainablePipe):
|
||||
|
@ -105,9 +110,11 @@ class CoreferenceResolver(TrainablePipe):
|
|||
self.span_mentions = span_mentions
|
||||
self.span_cluster_prefix = span_cluster_prefix
|
||||
self._rehearsal_model = None
|
||||
self.loss = CategoricalCrossentropy()
|
||||
|
||||
self.cfg = {}
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
TODO: write actual algorithm
|
||||
|
||||
|
@ -116,12 +123,27 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#predict (TODO)
|
||||
"""
|
||||
scores, idxs = self.model.predict(docs)
|
||||
# idxs is a list of mentions (start / end idxs)
|
||||
# each item in scores includes scores and a mapping from scores to mentions
|
||||
|
||||
xp = self.model.ops.xp
|
||||
|
||||
clusters_by_doc = []
|
||||
for i, doc in enumerate(docs):
|
||||
clusters = []
|
||||
for span in doc.spans[self.span_mentions]:
|
||||
clusters.append([span])
|
||||
clusters_by_doc.append(clusters)
|
||||
offset = 0
|
||||
for cscores, ant_idxs in scores:
|
||||
ll = cscores.shape[0]
|
||||
hi = offset + ll
|
||||
|
||||
starts = idxs[offset:hi, 0]
|
||||
ends = idxs[offset:hi, 1]
|
||||
|
||||
# need to add the placeholder
|
||||
placeholder = self.model.ops.alloc2f(cscores.shape[0], 1)
|
||||
cscores = xp.concatenate((placeholder, cscores), 1)
|
||||
|
||||
predicted = get_predicted_clusters(xp, starts, ends, ant_idxs, cscores)
|
||||
clusters_by_doc.append(predicted)
|
||||
return clusters_by_doc
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
||||
|
@ -133,18 +155,24 @@ class CoreferenceResolver(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
|
||||
"""
|
||||
if len(docs) != len(clusters_by_doc):
|
||||
raise ValueError("Found coref clusters incompatible with the "
|
||||
raise ValueError(
|
||||
"Found coref clusters incompatible with the "
|
||||
"documents provided to the 'coref' component. "
|
||||
"This is likely a bug in spaCy.")
|
||||
"This is likely a bug in spaCy."
|
||||
)
|
||||
for doc, clusters in zip(docs, clusters_by_doc):
|
||||
index = 0
|
||||
for cluster in clusters:
|
||||
key = self.span_cluster_prefix + str(index)
|
||||
for ii, cluster in enumerate(clusters):
|
||||
key = self.span_cluster_prefix + "_" + str(ii)
|
||||
if key in doc.spans:
|
||||
raise ValueError(f"Couldn't store the results of {self.name}, as the key "
|
||||
f"{key} already exists in 'doc.spans'.")
|
||||
doc.spans[key] = cluster
|
||||
index += 1
|
||||
raise ValueError(
|
||||
"Found coref clusters incompatible with the "
|
||||
"documents provided to the 'coref' component. "
|
||||
"This is likely a bug in spaCy."
|
||||
)
|
||||
|
||||
doc.spans[key] = []
|
||||
for mention in cluster:
|
||||
doc.spans[key].append(doc[mention[0] : mention[1]])
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
@ -174,13 +202,16 @@ class CoreferenceResolver(TrainablePipe):
|
|||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
|
||||
# TODO below
|
||||
# loss, d_scores = self.get_loss(examples, scores)
|
||||
# bp_scores(d_scores)
|
||||
|
||||
inputs = (example.predicted for example in examples)
|
||||
preds, backprop = self.model.begin_update(inputs)
|
||||
score_matrix, mention_idx = preds
|
||||
loss, d_scores = self.get_loss(examples, score_matrix, mention_idx)
|
||||
backprop(d_scores)
|
||||
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
# losses[self.name] += loss
|
||||
losses[self.name] += loss
|
||||
return losses
|
||||
|
||||
def rehearse(
|
||||
|
@ -236,7 +267,12 @@ class CoreferenceResolver(TrainablePipe):
|
|||
)
|
||||
)
|
||||
|
||||
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
|
||||
def get_loss(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
score_matrix: List[Tuple[Floats2d, Ints2d]],
|
||||
mention_idx: Ints2d,
|
||||
):
|
||||
"""Find the loss and gradient of loss for the batch of documents and
|
||||
their predicted scores.
|
||||
|
||||
|
@ -246,9 +282,46 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#get_loss (TODO)
|
||||
"""
|
||||
validate_examples(examples, "CoreferenceResolver.get_loss")
|
||||
# TODO
|
||||
return None
|
||||
ops = self.model.ops
|
||||
xp = ops.xp
|
||||
|
||||
offset = 0
|
||||
gradients = []
|
||||
loss = 0
|
||||
for example, (cscores, cidx) in zip(examples, score_matrix):
|
||||
# assume cids has absolute mention ids
|
||||
|
||||
ll = cscores.shape[0]
|
||||
hi = offset + ll
|
||||
|
||||
clusters = get_clusters_from_doc(example.reference)
|
||||
gscores = create_gold_scores(mention_idx[offset:hi], clusters)
|
||||
gscores = xp.asarray(gscores)
|
||||
top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||
# now add the placeholder
|
||||
gold_placeholder = ~top_gscores.any(axis=1).T
|
||||
gold_placeholder = xp.expand_dims(gold_placeholder, 1)
|
||||
top_gscores = xp.concatenate((gold_placeholder, top_gscores), 1)
|
||||
|
||||
# boolean to float
|
||||
top_gscores = ops.asarray2f(top_gscores)
|
||||
|
||||
# add the placeholder to cscores
|
||||
placeholder = self.model.ops.alloc2f(ll, 1)
|
||||
cscores = xp.concatenate((placeholder, cscores), 1)
|
||||
|
||||
# do softmax to cscores
|
||||
cscores = ops.softmax(cscores, axis=1)
|
||||
|
||||
diff = self.loss.get_grad(cscores, top_gscores)
|
||||
diff = diff[:, 1:]
|
||||
gradients.append((diff, cidx))
|
||||
|
||||
# scalar loss
|
||||
# loss += xp.sum(log_norm - log_marg)
|
||||
loss += self.loss.get_loss(cscores, top_gscores)
|
||||
offset += ll
|
||||
return loss, gradients
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
|
@ -279,10 +352,39 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#score (TODO)
|
||||
"""
|
||||
|
||||
def clusters_getter(doc, span_key):
|
||||
return [spans for name, spans in doc.spans.items() if name.startswith(span_key)]
|
||||
return [
|
||||
spans for name, spans in doc.spans.items() if name.startswith(span_key)
|
||||
]
|
||||
|
||||
validate_examples(examples, "CoreferenceResolver.score")
|
||||
kwargs.setdefault("getter", clusters_getter)
|
||||
kwargs.setdefault("attr", self.span_cluster_prefix)
|
||||
kwargs.setdefault("include_label", False)
|
||||
return Scorer.score_clusters(examples, **kwargs)
|
||||
|
||||
|
||||
# from ..coref_scorer import Evaluator, get_cluster_info, b_cubed
|
||||
# TODO consider whether to use this
|
||||
# def score(self, examples, **kwargs):
|
||||
# """Score a batch of examples."""
|
||||
#
|
||||
# #TODO traditionally coref uses the average of b_cubed, muc, and ceaf.
|
||||
# # we need to handle the average ourselves.
|
||||
# evaluator = Evaluator(b_cubed)
|
||||
#
|
||||
# for ex in examples:
|
||||
# p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
|
||||
# g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
|
||||
#
|
||||
# cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
#
|
||||
# evaluator.update(cluster_info)
|
||||
#
|
||||
# scores ={
|
||||
# "coref_f": evaluator.get_f1(),
|
||||
# "coref_p": evaluator.get_precision(),
|
||||
# "coref_r": evaluator.get_recall(),
|
||||
# }
|
||||
# return scores
|
||||
|
|
Loading…
Reference in New Issue
Block a user