From e6917d8dc4e2c7f9c2fc4db1ee6cb018066017b3 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Mon, 14 Mar 2022 19:27:55 +0900 Subject: [PATCH] Add util functions for wl-coref --- spacy/ml/models/coref_util_wl.py | 163 +++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 spacy/ml/models/coref_util_wl.py diff --git a/spacy/ml/models/coref_util_wl.py b/spacy/ml/models/coref_util_wl.py new file mode 100644 index 000000000..55326c217 --- /dev/null +++ b/spacy/ml/models/coref_util_wl.py @@ -0,0 +1,163 @@ +""" Contains functions not directly linked to coreference resolution """ + +from typing import List, Set, Dict, Tuple +from thinc.types import Ints1d +from dataclasses import dataclass +from ...tokens import Doc + +import torch + +EPSILON = 1e-7 + +class GraphNode: + def __init__(self, node_id: int): + self.id = node_id + self.links: Set[GraphNode] = set() + self.visited = False + + def link(self, another: "GraphNode"): + self.links.add(another) + another.links.add(self) + + def __repr__(self) -> str: + return str(self.id) + + +def add_dummy(tensor: torch.Tensor, eps: bool = False): + """ Prepends zeros (or a very small value if eps is True) + to the first (not zeroth) dimension of tensor. + """ + kwargs = dict(device=tensor.device, dtype=tensor.dtype) + shape: List[int] = list(tensor.shape) + shape[1] = 1 + if not eps: + dummy = torch.zeros(shape, **kwargs) # type: ignore + else: + dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore + output = torch.cat((dummy, tensor), dim=1) + return output + +def make_head_only_clusters(examples): + """Replace coref clusters with head-only clusters. + + This destructively modifies the docs. + """ + + #TODO what if all clusters are eliminated? + for eg in examples: + final = [] # save out clusters here + for key, sg in eg.reference.spans.items(): + if not key.startswith("coref_clusters_"): + continue + + heads = [span.root.i for span in sg] + heads = list(set(heads)) + head_spans = [eg.reference[hh:hh+1] for hh in heads] + if len(heads) > 1: + final.append(head_spans) + + # now delete the existing clusters + keys = list(eg.reference.spans.keys()) + for key in keys: + if not key.startswith("coref_clusters_"): + continue + + del eg.reference.spans[key] + + # now add the new spangroups + for ii, spans in enumerate(final): + #TODO support alternate keys + eg.reference.spans[f"coref_clusters_{ii}"] = spans + +# TODO replace with spaCy config +@dataclass +class CorefConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods + """ Contains values needed to set up the coreference model. """ + section: str + + data_dir: str + + train_data: str + dev_data: str + test_data: str + + device: str + + bert_model: str + bert_window_size: int + + embedding_size: int + sp_embedding_size: int + a_scoring_batch_size: int + hidden_size: int + n_hidden_layers: int + + max_span_len: int + + rough_k: int + + bert_finetune: bool + bert_mini_finetune: bool + dropout_rate: float + learning_rate: float + bert_learning_rate: float + train_epochs: int + bce_loss_weight: float + + tokenizer_kwargs: Dict[str, dict] + conll_log_dir: str + + +def get_sent_ids(doc): + sid = 0 + sids = [] + for sent in doc.sents: + for tok in sent: + sids.append(sid) + sid += 1 + return sids + +def get_cluster_ids(doc): + """Get the cluster ids of head tokens.""" + + out = [0] * len(doc) + head_spangroups = [doc.spans[sk] for sk in doc.spans if sk.startswith("coref_word_clusters")] + for ii, group in enumerate(head_spangroups, start=1): + for span in group: + out[span[0].i] = ii + + return out + +def get_head2span(doc): + out = [] + for sk in doc.spans: + if not sk.startswith("coref_clusters"): + continue + + if len(doc.spans[sk]) == 1: + print("===== UNARY MENTION ====") + + for span in doc.spans[sk]: + out.append( (span.root.i, span.start, span.end) ) + return out + + +def doc2tensors( + xp, + doc: Doc +) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]: + sent_ids = get_sent_ids(doc) + cluster_ids = get_cluster_ids(doc) + head2span = get_head2span(doc) + + + if not head2span: + heads, starts, ends = [], [], [] + else: + heads, starts, ends = zip(*head2span) + sent_ids = xp.asarray(sent_ids) + cluster_ids = xp.asarray(cluster_ids) + heads = xp.asarray(heads) + starts = xp.asarray(starts) + ends = xp.asarray(ends) - 1 + return sent_ids, cluster_ids, heads, starts, ends