Add util functions for wl-coref

Paul O'Leary McCann 2022-03-14 19:27:55 +09:00
parent dfec6993d6
commit e6917d8dc4


@@ -0,0 +1,163 @@
""" Contains functions not directly linked to coreference resolution """
from typing import List, Set, Dict, Tuple
from thinc.types import Ints1d
from dataclasses import dataclass
from ...tokens import Doc
import torch
EPSILON = 1e-7

class GraphNode:
    """A node in an undirected graph with symmetric links."""

    def __init__(self, node_id: int):
        self.id = node_id
        self.links: Set[GraphNode] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        self.links.add(another)
        another.links.add(self)

    def __repr__(self) -> str:
        return str(self.id)
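
# Illustrative usage sketch (not part of the original commit): links are
# symmetric, so linking one node to another also links it back.
def _example_graph_node() -> None:
    a, b = GraphNode(0), GraphNode(1)
    a.link(b)
    assert b in a.links and a in b.links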

def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
    shape: List[int] = list(tensor.shape)
    shape[1] = 1
    if not eps:
        dummy = torch.zeros(shape, **kwargs)          # type: ignore
    else:
        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
    output = torch.cat((dummy, tensor), dim=1)
    return output
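
# Illustrative usage sketch (not part of the original commit); the (3, 5) shape
# is an arbitrary stand-in for a per-word antecedent score matrix.
def _example_add_dummy() -> None:
    scores = torch.zeros(3, 5)
    padded = add_dummy(scores)                 # first column is zeros
    padded_eps = add_dummy(scores, eps=True)   # first column is EPSILON
    assert padded.shape == (3, 6)
    assert padded_eps.shape == (3, 6)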

def make_head_only_clusters(examples):
    """Replace coref clusters with head-only clusters.
    This destructively modifies the docs.
    """
    # TODO what if all clusters are eliminated?
    for eg in examples:
        final = []  # save out clusters here
        for key, sg in eg.reference.spans.items():
            if not key.startswith("coref_clusters_"):
                continue
            heads = [span.root.i for span in sg]
            heads = list(set(heads))
            head_spans = [eg.reference[hh:hh + 1] for hh in heads]
            # skip clusters that collapse to a single head token
            if len(heads) > 1:
                final.append(head_spans)
        # now delete the existing clusters
        keys = list(eg.reference.spans.keys())
        for key in keys:
            if not key.startswith("coref_clusters_"):
                continue
            del eg.reference.spans[key]
        # now add the new spangroups
        for ii, spans in enumerate(final):
            # TODO support alternate keys
            eg.reference.spans[f"coref_clusters_{ii}"] = spans
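
# Hedged sketch of how make_head_only_clusters might be exercised; the blank
# pipeline, text, and span indices below are made-up for illustration.
def _example_make_head_only_clusters() -> None:
    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    ref = nlp("John said he would come .")
    ref.spans["coref_clusters_1"] = [ref[0:1], ref[2:3]]  # "John" and "he"
    eg = Example(nlp("John said he would come ."), ref)
    make_head_only_clusters([eg])
    # every remaining cluster now contains only single-token (head) spans
    assert all(
        len(span) == 1
        for group in eg.reference.spans.values()
        for span in group
    )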

# TODO replace with spaCy config
@dataclass
class CorefConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """ Contains values needed to set up the coreference model. """

    section: str
    data_dir: str
    train_data: str
    dev_data: str
    test_data: str
    device: str
    bert_model: str
    bert_window_size: int
    embedding_size: int
    sp_embedding_size: int
    a_scoring_batch_size: int
    hidden_size: int
    n_hidden_layers: int
    max_span_len: int
    rough_k: int
    bert_finetune: bool
    bert_mini_finetune: bool
    dropout_rate: float
    learning_rate: float
    bert_learning_rate: float
    train_epochs: int
    bce_loss_weight: float
    tokenizer_kwargs: Dict[str, dict]
    conll_log_dir: str

def get_sent_ids(doc):
    """Return the sentence index of each token in the doc."""
    sid = 0
    sids = []
    for sent in doc.sents:
        for tok in sent:
            sids.append(sid)
        sid += 1
    return sids

def get_cluster_ids(doc):
    """Get the cluster ids of head tokens."""
    out = [0] * len(doc)
    head_spangroups = [
        doc.spans[sk] for sk in doc.spans if sk.startswith("coref_word_clusters")
    ]
    for ii, group in enumerate(head_spangroups, start=1):
        for span in group:
            # each span is a single head token; 0 means "not in any cluster"
            out[span[0].i] = ii
    return out
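
# Hedged sketch (illustrative only): token positions map to 1-based cluster
# ids, with 0 for tokens that are not the head of any mention.
def _example_get_cluster_ids() -> None:
    import spacy

    nlp = spacy.blank("en")
    doc = nlp("John said he would come .")
    doc.spans["coref_word_clusters_1"] = [doc[0:1], doc[2:3]]  # heads "John", "he"
    assert get_cluster_ids(doc) == [1, 0, 1, 0, 0, 0]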

def get_head2span(doc):
    """Return (head index, span start, span end) triples for mention spans."""
    out = []
    for sk in doc.spans:
        if not sk.startswith("coref_clusters"):
            continue
        if len(doc.spans[sk]) == 1:
            print("===== UNARY MENTION ====")
        for span in doc.spans[sk]:
            out.append((span.root.i, span.start, span.end))
    return out

def doc2tensors(
    xp,
    doc: Doc
) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]:
    sent_ids = get_sent_ids(doc)
    cluster_ids = get_cluster_ids(doc)
    head2span = get_head2span(doc)
    if not head2span:
        heads, starts, ends = [], [], []
    else:
        heads, starts, ends = zip(*head2span)
    sent_ids = xp.asarray(sent_ids)
    cluster_ids = xp.asarray(cluster_ids)
    heads = xp.asarray(heads)
    starts = xp.asarray(starts)
    # make the exclusive span ends inclusive
    ends = xp.asarray(ends) - 1
    return sent_ids, cluster_ids, heads, starts, ends
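
# Hedged end-to-end sketch (illustrative only; numpy stands in for the xp
# array module, and the sentencizer and spans are made-up inputs). This
# exercises get_sent_ids, get_cluster_ids, and get_head2span via doc2tensors.
def _example_doc2tensors() -> None:
    import numpy
    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer")
    doc = nlp("John said he would come .")
    doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]       # mention spans
    doc.spans["coref_word_clusters_1"] = [doc[0:1], doc[2:3]]  # head-word spans
    sent_ids, cluster_ids, heads, starts, ends = doc2tensors(numpy, doc)
    assert len(sent_ids) == len(doc) == len(cluster_ids)
    assert list(heads) == [0, 2] and list(ends) == [0, 2]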