# spaCy/spacy/ml/models/coref_util.py
# 2022-07-04 19:28:35 +09:00
#
# 220 lines
# 7.0 KiB
# Python

from typing import List, Tuple, Set, Dict, Optional, cast

from thinc.types import Ints2d

from spacy.tokens import Doc
# Type alias to make writing this less tedious: a list of clusters, where
# each cluster is a list of (int, int) mention spans.
MentionClusters = List[List[Tuple[int, int]]]
# NOTE(review): presumably the key prefix under which coref clusters are
# stored in doc.spans; nothing in this part of the file reads it -- confirm
# against the callers.
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
class GraphNode:
    """A node in an undirected graph, identified by an integer id.

    Attributes are the node id (`id`), the set of neighbouring nodes
    (`links`) and a `visited` flag that starts out False.
    """

    def __init__(self, node_id: int):
        self.id = node_id
        # Neighbours are stored on both endpoints, see link().
        self.links: Set["GraphNode"] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        """Connect this node and `another` in both directions."""
        for source, target in ((self, another), (another, self)):
            source.links.add(target)

    def __repr__(self) -> str:
        """Show the node as its bare id."""
        return str(self.id)
def get_sentence_ids(doc):
    """Return, for every token in `doc`, the index of its sentence.

    The counter starts at -1 and is bumped on every token whose
    `is_sent_start` is truthy, so the first sentence gets id 0 and any
    tokens seen before the first sentence start get -1.
    """
    sentence_ids = []
    current = -1
    for token in doc:
        current += 1 if token.is_sent_start else 0
        sentence_ids.append(current)
    return sentence_ids
# from model.py, refactored to be non-member
def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
    """Get the ID of the antecedent for each span. -1 if no antecedent.

    The best-scoring column of each row of `antecedent_scores` is shifted
    down by one, so a win by column 0 means "no antecedent"; any other
    column selects the matching entry of that row in `antecedent_idx`.
    """
    best_columns = xp.argmax(antecedent_scores, axis=1) - 1
    return [
        -1 if best < 0 else antecedent_idx[row][best]
        for row, best in enumerate(best_columns)
    ]
# from model.py, refactored to be non-member
def get_predicted_clusters(
    xp, span_starts, span_ends, antecedent_idx, antecedent_scores
):
    """Convert antecedent predictions into clusters of spans.

    Each span with a predicted antecedent is added to the cluster of that
    antecedent; a new cluster is opened when the antecedent is not in one
    yet.

    RETURNS: a list of clusters, each a tuple of (start, end) spans.

    Note that not all spans will be in the final output; spans with no
    antecedent or referent are omitted.
    """

    def span_at(index):
        # Plain ints so the spans are hashable and comparable as dict keys.
        return (int(span_starts[index]), int(span_ends[index]))

    predicted_antecedents = get_predicted_antecedents(
        xp, antecedent_idx, antecedent_scores
    )

    cluster_of = {}
    clusters = []
    for i, predicted_idx in enumerate(predicted_antecedents):
        # A negative index means this span has no antecedent.
        if predicted_idx < 0:
            continue
        assert i > predicted_idx, f"span idx: {i}; antecedent idx: {predicted_idx}"

        # Open a new cluster if the antecedent does not belong to one yet.
        antecedent = span_at(predicted_idx)
        if antecedent not in cluster_of:
            cluster_of[antecedent] = len(clusters)
            clusters.append([antecedent])
        cluster_id = cluster_of[antecedent]

        # The mention joins its antecedent's cluster.
        mention = span_at(i)
        clusters[cluster_id].append(mention)
        cluster_of[mention] = cluster_id

    return [tuple(members) for members in clusters]
def select_non_crossing_spans(
    idxs: List[int], starts: List[int], ends: List[int], limit: int
) -> List[int]:
    """Given a list of spans sorted in descending order, return the indexes of
    spans to keep, discarding spans that cross.

    Candidates are visited in the order given by `idxs`, and a candidate is
    dropped when it crosses a span that was kept earlier. Nested spans are
    allowed.

    idxs: Candidate span indexes into `starts` / `ends`, in priority order.
    starts: Start offset of every span.
    ends: End offset of every span.
    limit: Maximum number of spans to keep.
    RETURNS: Indexes of the kept spans, sorted by (start, end).
    """
    # ported from Model._extract_top_spans
    selected: List[int] = []
    start_to_max_end: Dict[int, int] = {}
    end_to_min_start: Dict[int, int] = {}
    n_spans = len(starts)

    for idx in idxs:
        # Stop when enough spans are kept, or when the index does not refer
        # to a span. `>=` matters: idx == len(starts) is already out of
        # range and would raise IndexError on the lookup below.
        if len(selected) >= limit or idx >= n_spans:
            break
        start, end = starts[idx], ends[idx]

        # NOTE(review): range(start, end) never visits `end`, so `ti < end`
        # is always true here. Kept as ported; confirm whether the original
        # implementation iterated over an inclusive end.
        for ti in range(start, end):
            # A kept span starts inside this one and ends after it.
            if ti > start and start_to_max_end.get(ti, -1) > end:
                break
            # A kept span ends inside this one and starts before it.
            if ti < end and 0 <= end_to_min_start.get(ti, -1) < start:
                break
        else:
            # No crossing found: keep this index and record its boundaries
            # so anything that crosses it is excluded later.
            selected.append(idx)
            if end > start_to_max_end.get(start, -1):
                start_to_max_end[start] = end
            min_start = end_to_min_start.get(end, -1)
            if min_start == -1 or start < min_start:
                end_to_min_start[end] = start

    # sort idxs by order in doc
    selected.sort(key=lambda kept: (starts[kept], ends[kept]))
    # The result is deliberately not padded up to `limit`: padding with
    # repeats of the first span caused many repetitive entities in the output.
    return selected
def create_head_span_idxs(ops, doclen: int):
    """Helper function to create single-token span indices.

    Builds one (i, i + 1) row for every i in range(doclen), using the array
    module and int conversion provided by `ops`.
    """
    starts = ops.xp.arange(0, doclen)
    ends = starts + 1
    # Stack as [starts, ends] and transpose so each row is one span.
    return ops.asarray2i([starts, ends]).T
def get_clusters_from_doc(
    doc, *, prefix: Optional[str] = None
) -> List[List[Tuple[int, int]]]:
    """Given a Doc, convert the cluster spans to simple int tuple lists. The
    ints are char spans, to be tokenization independent.

    Every span is represented by its head (root) token only, as a
    (start_char, end_char) tuple covering that token.

    doc: The Doc whose `doc.spans` groups are read.
    prefix: If given, only span groups whose key starts with this string are
        converted. The default (None) converts every span group.
    RETURNS: One list per converted span group, holding the unique head-token
        char spans in order of first occurrence.
    """
    out = []
    for key, val in doc.spans.items():
        # Skip span groups that are not clusters when a prefix is requested.
        if prefix is not None and not key.startswith(prefix):
            continue
        heads = (doc[span.root.i] for span in val)
        char_spans = [(head.idx, head.idx + len(head)) for head in heads]
        # don't want duplicates; dict.fromkeys keeps first-seen order, so the
        # result does not depend on set iteration order
        out.append(list(dict.fromkeys(char_spans)))
    return out
def create_gold_scores(
    ments: Ints2d, clusters: List[List[Tuple[int, int]]]
) -> List[List[bool]]:
    """Build the gold antecedent matrix for the candidate mentions.

    Row i, column j is True when mention j comes before mention i and both
    belong to the same gold cluster; every other cell is False. The
    placeholder column is not included, and the caller is responsible for
    converting the result to an array and adding it.
    """
    # Only equality of cluster ids matters, not their values.
    cluster_of: Dict[Tuple[int, int], int] = {
        span: cluster_id
        for cluster_id, cluster in enumerate(clusters)
        for span in cluster
    }
    # The .tolist() call is necessary with cupy but not numpy
    spans = [cast(Tuple[int, int], tuple(row.tolist())) for row in ments]
    width = len(ments)

    gold = []
    for pos, span in enumerate(spans):
        own_cluster = cluster_of.get(span)
        if own_cluster is None:
            # Not in any cluster, so it has no antecedent.
            gold.append([False] * width)
        else:
            # Antecedents must come first, hence the `cand < pos` guard.
            gold.append(
                [
                    cand < pos and cluster_of.get(other, -1) == own_cluster
                    for cand, other in enumerate(spans)
                ]
            )
    return gold
def _spans_to_offsets(doc):
    """Convert doc.spans to a nested list of character offsets.

    Span groups are visited in sorted key order, and each one becomes a list
    of (start_char, end_char) tuples. This is useful for checking
    consistency of predictions.
    """
    groups = doc.spans
    return [
        [(span.start_char, span.end_char) for span in groups[name]]
        for name in sorted(groups)
    ]