spaCy/spacy/ml/models/coref_util.py
Paul O'Leary McCann 7c42a8c90a Migrate coref code
This includes the coref code that was being tested separately, modified
to work in spaCy. It hasn't been tested yet and presumably still needs
fixes.

In particular, the evaluation code is currently omitted. It's unclear at
the moment whether we want to use a complex scorer similar to the
official one, or a simpler scorer using more modern evaluation methods.
2021-05-15 21:36:10 +09:00

253 lines
7.9 KiB
Python

from thinc.types import Ints2d
from spacy.tokens import Doc
from typing import List, Tuple
# type alias to make writing this less tedious
# A list of clusters, where each cluster is a list of (start, end) int pairs.
MentionClusters = List[List[Tuple[int, int]]]
# Default key prefix that doc2clusters() uses to pick span groups out of
# doc.spans; keys that do not start with it are skipped there.
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
    """Collect the mention clusters stored on a doc.

    Every span group in ``doc.spans`` whose key starts with ``prefix`` becomes
    one cluster, expressed as a list of ``(start, end)`` pairs taken from each
    mention. Groups with other keys are ignored.

    This is useful for scoring.
    """
    return [
        [(mention.start, mention.end) for mention in group]
        for name, group in doc.spans.items()
        if name.startswith(prefix)
    ]
def topk(xp, arr, k, axis=None):
    """Given an array and a k value, give the top values and idxs for each row.

    xp: array module to use (numpy-compatible, e.g. numpy or cupy).
    arr: 2d array; the top entries are selected along axis 1.
    k: number of entries to keep per row.
    axis: accepted for signature compatibility but not used; the operation
        always runs along axis 1.

    Returns a tuple ``(values, idxs)`` of arrays with k columns. Within each
    row the values are in ascending order, and ``idxs`` holds the column in
    ``arr`` that each value came from.
    """
    # argpartition puts the k largest entries of each row in the last k slots.
    part = xp.argpartition(arr, -k, axis=1)
    # Flip along the columns only. Without an explicit axis, flip reverses
    # every axis, which would pair each row with another row's indices.
    idxs = xp.flip(part, axis=1)[:, :k]
    vals = xp.take_along_axis(arr, idxs, axis=1)

    sidxs = xp.argsort(vals, axis=1)
    # map these idxs back to the original
    oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
    svals = xp.take_along_axis(vals, sidxs, axis=1)
    return svals, oidxs
def logsumexp(xp, arr, axis=None):
    """Emulate torch.logsumexp by returning the log of summed exponentials
    along the given dimension.

    xp: array module to use (numpy-compatible, e.g. numpy or cupy).
    arr: input array.
    axis: axis to reduce over. With ``axis=1`` a 2d array is reduced to 1d;
        with ``None`` the whole array is reduced to a single value.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow.
    """
    # from slide 5 here:
    # https://www.slideshare.net/ryokuta/cupy
    # keepdims leaves a length-1 axis in place so the subtraction broadcasts
    # correctly for any axis, not only axis=1.
    hi = arr.max(axis=axis, keepdims=True)
    shifted_sum = xp.exp(arr - hi).sum(axis=axis)
    # Drop only the reduced axis so the shapes of both terms line up.
    return xp.squeeze(hi, axis=axis) + xp.log(shifted_sum)
# from model.py, refactored to be non-member
def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
    """Pick the best-scoring antecedent for every span.

    For each row of ``antecedent_scores`` the argmax column is taken and
    shifted down by one. A shifted value below zero (i.e. column 0 won) yields
    -1, meaning "no antecedent"; otherwise the entry of ``antecedent_idx`` at
    that row and shifted column is returned.

    Returns a list with one entry per row.
    """
    # Column 0 is presumably the "no antecedent" placeholder, hence the shift.
    best_cols = xp.argmax(antecedent_scores, axis=1) - 1
    return [
        -1 if col < 0 else antecedent_idx[row][col]
        for row, col in enumerate(best_cols)
    ]
# from model.py, refactored to be non-member
def get_predicted_clusters(
    xp, span_starts, span_ends, antecedent_idx, antecedent_scores
):
    """Turn antecedent predictions into clusters of spans.

    Each span with a predicted antecedent is placed in the same cluster as
    that antecedent; a new cluster is opened when the antecedent is not yet
    in one.

    Returns a list of clusters, each a tuple of ``(start, end)`` int pairs.

    Spans that have no antecedent and are not the antecedent of another span
    do not appear in the output.
    """
    predicted_antecedents = get_predicted_antecedents(
        xp, antecedent_idx, antecedent_scores
    )

    cluster_of = {}  # (start, end) -> position of its cluster in `clusters`
    clusters = []
    for i, predicted_idx in enumerate(predicted_antecedents):
        if predicted_idx < 0:
            # no antecedent predicted for this span
            continue
        assert i > predicted_idx, f"span idx: {i}; antecedent idx: {predicted_idx}"

        # Find the antecedent's cluster, opening a new one if needed.
        ante_span = (int(span_starts[predicted_idx]), int(span_ends[predicted_idx]))
        cluster_id = cluster_of.get(ante_span)
        if cluster_id is None:
            cluster_id = len(clusters)
            clusters.append([ante_span])
            cluster_of[ante_span] = cluster_id

        # The current span joins its antecedent's cluster.
        this_span = (int(span_starts[i]), int(span_ends[i]))
        clusters[cluster_id].append(this_span)
        cluster_of[this_span] = cluster_id

    return [tuple(members) for members in clusters]
def get_sentence_map(doc: Doc):
    """Return, for each token of the doc, the index of the sentence it is in.

    Sentences are numbered from 0 in the order ``doc.sents`` yields them, and
    the output has one entry per token yielded by each sentence.
    """
    return [sent_i for sent_i, sent in enumerate(doc.sents) for _ in sent]
def get_candidate_mentions(
    doc: Doc, max_span_width: int = 20
) -> Tuple[List[int], List[int]]:
    """Given a Doc, return candidate mentions.

    This isn't a trainable layer, it just returns raw candidates.

    Every span of 1 to ``max_span_width`` tokens that lies entirely inside one
    sentence is a candidate.

    Returns a tuple ``(begins, ends)`` of parallel lists. Candidate ``n``
    covers ``doc[begins[n]:ends[n]]``, so the end index is exclusive.
    Candidates are ordered by start index, then by end index.
    """
    # XXX Note that in coref-hoi the indexes are designed so you actually want [i:j+1], but here
    # we're using [i:j], which is more natural.
    begins = []
    ends = []
    for sent in doc.sents:
        sent_end = sent.end  # exclusive, like the candidate ends
        for start in range(sent.start, sent_end):
            # The end is exclusive, so it may equal the sentence end. This
            # lets the last token of a sentence take part in candidates.
            last_end = min(sent_end, start + max_span_width)
            for end in range(start + 1, last_end + 1):
                begins.append(start)
                ends.append(end)
    return (begins, ends)
def select_non_crossing_spans(
    idxs: List[int], starts: List[int], ends: List[int], limit: int
) -> List[int]:
    """Given a list of spans sorted in descending order, return the indexes of
    spans to keep, discarding spans that cross.

    Nested spans are allowed.

    idxs: span indexes to consider, best first.
    starts, ends: boundaries of each span, indexed by span index.
    limit: number of indexes to return.

    Spans are taken greedily in the order given. A span is dropped if it
    partially overlaps a span that was already kept. The overlap test walks
    ``start`` through ``end`` inclusive, so two spans sharing a boundary index
    count as overlapping.

    Returns the kept indexes sorted by ``(start, end)``. If fewer than
    ``limit`` spans were kept, the list is padded with copies of its first
    entry. If no span was kept, the result is an empty list.
    """
    # ported from Model._extract_top_spans
    selected = []
    start_to_max_end = {}
    end_to_min_start = {}
    num_spans = len(starts)

    for idx in idxs:
        # `>=` because idx == len(starts) is already out of range.
        if len(selected) >= limit or idx >= num_spans:
            break

        start, end = starts[idx], ends[idx]
        cross = False

        for ti in range(start, end + 1):
            # A kept span that starts inside this one and ends after it.
            max_end = start_to_max_end.get(ti, -1)
            if ti > start and max_end > end:
                cross = True
                break

            # A kept span that ends inside this one and starts before it.
            min_start = end_to_min_start.get(ti, -1)
            if ti < end and 0 <= min_start < start:
                cross = True
                break

        if cross:
            continue

        # this index will be kept
        # record it so we can exclude anything that crosses it
        selected.append(idx)
        if end > start_to_max_end.get(start, -1):
            start_to_max_end[start] = end

        # Compare against the recorded minimum. -1 means nothing has been
        # recorded for this end yet.
        min_start = end_to_min_start.get(end, -1)
        if min_start == -1 or start < min_start:
            end_to_min_start[end] = start

    # sort idxs by order in doc
    selected.sort(key=lambda idx: (starts[idx], ends[idx]))
    if selected:
        # this seems a bit weird? Pad up to `limit` with the first entry.
        selected.extend([selected[0]] * (limit - len(selected)))
    return selected
def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
    """Convert every span group on a Doc into a list of int tuples.

    Each entry of ``doc.spans`` becomes one cluster, and each span in it is
    reduced to its ``(start, end)`` pair. All groups are included, whatever
    their key.
    """
    # TODO check that there isn't an off-by-one error here
    return [
        [(span.start, span.end) for span in group]
        for _, group in doc.spans.items()
    ]
def make_clean_doc(nlp, doc):
    """Build a new Doc from the words and sentence starts of ``doc``.

    The new Doc uses ``nlp.vocab`` and is given only the token texts and each
    token's ``is_sent_start`` value, so span annotations are not copied over.
    """
    # Surely there is a better way to do this?
    tokens = list(doc)
    return Doc(
        nlp.vocab,
        words=[tok.text for tok in tokens],
        sent_starts=[tok.is_sent_start for tok in tokens],
    )
def create_gold_scores(
    ments: Ints2d, clusters: List[List[Tuple[int, int]]]
) -> List[List[bool]]:
    """Build the gold antecedent matrix for a set of candidate mentions.

    ments: the mentions considered for antecedents, one row per mention.
    clusters: gold clusters as lists of mention tuples.

    Returns one row per mention, each with one bool per mention. Entry
    ``[i][j]`` is True only when ``j < i`` and mentions ``i`` and ``j`` are in
    the same gold cluster. Mentions that are in no cluster get an all-False
    row. The placeholder column is not included; the caller needs to convert
    the result to an array and add it.
    """
    # Map each gold mention to its cluster id. The id itself is not
    # important, only whether two mentions share it.
    cluster_of = {
        ment: cid for cid, cluster in enumerate(clusters) for ment in cluster
    }

    total = len(ments)
    # The .tolist() call is necessary with cupy but not numpy
    spans = [tuple(row.tolist()) for row in ments]

    scores = []
    for ii, ment in enumerate(spans):
        if ment not in cluster_of:
            # this is not in a cluster so it has no antecedent
            scores.append([False] * total)
            continue

        # this might change if no real antecedent is a candidate
        cid = cluster_of[ment]
        # antecedents must come first, so only earlier mentions can match
        earlier = [cluster_of.get(ante, -1) == cid for ante in spans[:ii]]
        scores.append(earlier + [False] * (total - ii))

    return scores