spaCy/spacy/ml/models/coref_util_wl.py

""" Contains functions not directly linked to coreference resolution """
from typing import List, Set, Dict, Tuple
from thinc.types import Ints1d
from dataclasses import dataclass
from ...tokens import Doc
from ...language import Language
import torch
EPSILON = 1e-7
class GraphNode:
    """A node in an undirected graph. Links are symmetric, so coreference
    clusters can be read off as connected components; the `visited` flag
    supports that traversal."""

    def __init__(self, node_id: int):
        self.id = node_id
        self.links: Set["GraphNode"] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        """Add a symmetric edge between this node and `another`."""
        self.links.add(another)
        another.links.add(self)

    def __repr__(self) -> str:
        return str(self.id)
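
# Illustrative only: a minimal sketch of how linked nodes behave. The node
# names below are made up for the example.
#
#   a, b, c = GraphNode(0), GraphNode(1), GraphNode(2)
#   a.link(b)
#   b.link(c)
#   assert a in b.links and b in a.links  # link() adds both directions,
#   # so {a, b, c} form one connected component (one cluster)
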
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
    shape: List[int] = list(tensor.shape)
    shape[1] = 1
    if not eps:
        dummy = torch.zeros(shape, **kwargs)  # type: ignore
    else:
        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
    output = torch.cat((dummy, tensor), dim=1)
    return output
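
# Illustrative only: add_dummy is typically used to prepend a "no antecedent"
# column to a pairwise score matrix, so that index 0 can stand for the dummy
# antecedent. The shapes here are an assumption for the sketch, not taken
# from this file.
#
#   scores = torch.rand(10, 5)                # (n_words, n_ants)
#   padded = add_dummy(scores)                # (10, 6); padded[:, 0] is zeros
#   padded_eps = add_dummy(scores, eps=True)  # column 0 filled with EPSILON
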
# TODO replace with spaCy config
@dataclass
class CorefConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """ Contains values needed to set up the coreference model. """
    section: str
    data_dir: str
    train_data: str
    dev_data: str
    test_data: str
    device: str
    bert_model: str
    bert_window_size: int
    embedding_size: int
    sp_embedding_size: int
    a_scoring_batch_size: int
    hidden_size: int
    n_hidden_layers: int
    max_span_len: int
    rough_k: int
    bert_finetune: bool
    bert_mini_finetune: bool
    dropout_rate: float
    learning_rate: float
    bert_learning_rate: float
    train_epochs: int
    bce_loss_weight: float
    tokenizer_kwargs: Dict[str, dict]
    conll_log_dir: str
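
# Illustrative only: every field is required, so instances are normally
# built from a parsed config mapping rather than written out by hand,
# e.g. (hypothetical names):
#
#   cfg = CorefConfig(**config_sections["roberta"])
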
def get_sent_ids(doc):
    """Return a list with one sentence index per token in `doc`."""
    sid = 0
    sids = []
    for sent in doc.sents:
        for tok in sent:
            sids.append(sid)
        sid += 1
    return sids
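
# Worked example (illustrative): for a doc whose two sentences have
# 3 and 2 tokens respectively, get_sent_ids returns [0, 0, 0, 1, 1],
# one sentence index per token.
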
def get_cluster_ids(doc):
    """Get the cluster ids of head tokens."""
    # 0 means "not in any cluster"; cluster ids are 1-based
    out = [0] * len(doc)
    head_spangroups = [
        doc.spans[sk]
        for sk in doc.spans
        if sk.startswith("coref_word_clusters")
    ]
    for ii, group in enumerate(head_spangroups, start=1):
        for span in group:
            out[span[0].i] = ii
    return out
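
# Illustrative only: if doc.spans contains a "coref_word_clusters_1" group
# whose head spans start at tokens 2 and 7, the result is a list of zeros
# with out[2] == out[7] == 1 (the span key and indices are assumptions for
# the example).
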
def get_head2span(doc):
    """Return a (head index, span start, span end) triple for every span
    in the "coref_clusters*" span groups."""
    out = []
    for sk in doc.spans:
        if not sk.startswith("coref_clusters"):
            continue
        if len(doc.spans[sk]) == 1:
            # debug output for degenerate single-mention clusters
            print("===== UNARY MENTION =====")
        for span in doc.spans[sk]:
            out.append((span.root.i, span.start, span.end))
    return out
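
# Illustrative only: for a cluster span covering tokens 3-5 whose syntactic
# root is token 5, the triple is (5, 3, 6). Note `end` is exclusive here;
# doc2tensors below converts it to an inclusive last-token index.
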
def doc2tensors(
    xp,
    doc: Doc
) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]:
    sent_ids = get_sent_ids(doc)
    cluster_ids = get_cluster_ids(doc)
    head2span = get_head2span(doc)
    if not head2span:
        heads, starts, ends = [], [], []
    else:
        heads, starts, ends = zip(*head2span)
    sent_ids = xp.asarray(sent_ids)
    cluster_ids = xp.asarray(cluster_ids)
    heads = xp.asarray(heads)
    starts = xp.asarray(starts)
    # convert exclusive span ends to inclusive last-token indices
    ends = xp.asarray(ends) - 1
    return sent_ids, cluster_ids, heads, starts, ends
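
# Illustrative only: `xp` is expected to be a numpy-compatible array module
# (numpy on CPU, cupy on GPU), following the usual spaCy/Thinc convention:
#
#   import numpy
#   sent_ids, cluster_ids, heads, starts, ends = doc2tensors(numpy, doc)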