Add util functions for wl-coref
This commit is contained in:
parent dfec6993d6
commit e6917d8dc4
spacy/ml/models/coref_util_wl.py (new file, 163 lines)
@@ -0,0 +1,163 @@
""" Contains functions not directly linked to coreference resolution """
|
||||
|
||||
from typing import List, Set, Dict, Tuple
|
||||
from thinc.types import Ints1d
|
||||
from dataclasses import dataclass
|
||||
from ...tokens import Doc
|
||||
|
||||
import torch
|
||||
|
||||
EPSILON = 1e-7
|
||||
|
||||
class GraphNode:
    """A node in an undirected graph; links are stored bidirectionally."""

    def __init__(self, node_id: int):
        self.id = node_id
        self.links: Set["GraphNode"] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        """Connect this node and `another` in both directions."""
        self.links.add(another)
        another.links.add(self)

    def __repr__(self) -> str:
        return str(self.id)

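# Usage sketch (illustrative, not part of the original file): nodes can be
# linked into an undirected graph whose connected components can later be
# read off as clusters.
#
#     a, b, c = GraphNode(0), GraphNode(1), GraphNode(2)
#     a.link(b)                            # symmetric: updates both nodes
#     assert a in b.links and b in a.links
#     assert c.links == set()
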
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
    shape: List[int] = list(tensor.shape)
    shape[1] = 1
    if not eps:
        dummy = torch.zeros(shape, **kwargs)  # type: ignore
    else:
        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
    output = torch.cat((dummy, tensor), dim=1)
    return output

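# Shape sketch (illustrative): for a (batch, n, m) tensor the dummy is
# prepended along dim 1, adding one extra slot at index 0 (e.g. a dummy
# antecedent column in coreference scoring).
#
#     scores = torch.rand(2, 5, 5)
#     add_dummy(scores).shape                              # torch.Size([2, 6, 5])
#     bool((add_dummy(scores, eps=True)[:, 0] == EPSILON).all())  # True
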
def make_head_only_clusters(examples):
    """Replace coref clusters with head-only clusters.

    This destructively modifies the docs.
    """
    # TODO: what if all clusters are eliminated?
    for eg in examples:
        final = []  # save out clusters here
        for key, sg in eg.reference.spans.items():
            if not key.startswith("coref_clusters_"):
                continue

            heads = [span.root.i for span in sg]
            heads = list(set(heads))
            head_spans = [eg.reference[hh : hh + 1] for hh in heads]
            # drop clusters that collapse to a single head
            if len(heads) > 1:
                final.append(head_spans)

        # now delete the existing clusters
        keys = list(eg.reference.spans.keys())
        for key in keys:
            if not key.startswith("coref_clusters_"):
                continue
            del eg.reference.spans[key]

        # now add the new spangroups
        for ii, spans in enumerate(final):
            # TODO: support alternate keys
            eg.reference.spans[f"coref_clusters_{ii}"] = spans

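# Before/after sketch (illustrative, assuming `eg` is a spacy.training.Example
# whose reference doc carries a parsed two-mention cluster):
#
#     eg.reference.spans["coref_clusters_1"]   # [the big dog, it]
#     make_head_only_clusters([eg])
#     eg.reference.spans["coref_clusters_0"]   # [dog, it] (head tokens only)
#
# Note that keys are renumbered from 0 and clusters that collapse to a single
# head are dropped.
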
# TODO: replace with spaCy config
@dataclass
class CorefConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """Contains values needed to set up the coreference model."""

    section: str

    data_dir: str

    train_data: str
    dev_data: str
    test_data: str

    device: str

    bert_model: str
    bert_window_size: int

    embedding_size: int
    sp_embedding_size: int
    a_scoring_batch_size: int
    hidden_size: int
    n_hidden_layers: int

    max_span_len: int

    rough_k: int

    bert_finetune: bool
    bert_mini_finetune: bool
    dropout_rate: float
    learning_rate: float
    bert_learning_rate: float
    train_epochs: int
    bce_loss_weight: float

    tokenizer_kwargs: Dict[str, dict]
    conll_log_dir: str

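# Construction sketch (illustrative): the dataclass declares no defaults, so
# every field must be supplied, e.g. by unpacking a parsed config mapping
# (`raw_config` is a hypothetical dict, not something defined in this file):
#
#     config = CorefConfig(**raw_config)
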
def get_sent_ids(doc):
    """Return one sentence index per token, counting up from 0."""
    sid = 0
    sids = []
    for sent in doc.sents:
        for tok in sent:
            sids.append(sid)
        sid += 1
    return sids

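# Example (illustrative, assuming an `nlp` pipeline that sets sentence
# boundaries):
#
#     doc = nlp("I saw her. She waved.")
#     get_sent_ids(doc)        # [0, 0, 0, 0, 1, 1, 1]
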
def get_cluster_ids(doc):
    """Get the cluster ids of head tokens."""
    out = [0] * len(doc)
    head_spangroups = [
        doc.spans[sk] for sk in doc.spans if sk.startswith("coref_word_clusters")
    ]
    for ii, group in enumerate(head_spangroups, start=1):
        for span in group:
            # each span is a single head token; 0 means "not in any cluster"
            out[span[0].i] = ii
    return out

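# Example (illustrative): tokens outside any cluster keep id 0; cluster ids
# start at 1, following the order of the "coref_word_clusters*" span groups.
#
#     doc.spans["coref_word_clusters_1"] = [doc[1:2], doc[4:5]]
#     get_cluster_ids(doc)     # 1 at positions 1 and 4, 0 elsewhere
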
def get_head2span(doc):
    """Return (head index, span start, span end) triples for all mentions."""
    out = []
    for sk in doc.spans:
        if not sk.startswith("coref_clusters"):
            continue

        if len(doc.spans[sk]) == 1:
            print("===== UNARY MENTION ====")

        for span in doc.spans[sk]:
            out.append((span.root.i, span.start, span.end))
    return out

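# Example (illustrative): one (head, start, end) triple per mention, with
# spaCy's usual exclusive end offsets.
#
#     doc.spans["coref_clusters_1"] = [doc[0:3], doc[5:6]]
#     get_head2span(doc)       # [(root index of doc[0:3], 0, 3), (5, 5, 6)]
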
def doc2tensors(
    xp,
    doc: Doc,
) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]:
    sent_ids = get_sent_ids(doc)
    cluster_ids = get_cluster_ids(doc)
    head2span = get_head2span(doc)

    if not head2span:
        heads, starts, ends = [], [], []
    else:
        heads, starts, ends = zip(*head2span)
    sent_ids = xp.asarray(sent_ids)
    cluster_ids = xp.asarray(cluster_ids)
    heads = xp.asarray(heads)
    starts = xp.asarray(starts)
    # convert exclusive span ends to the index of the last token
    ends = xp.asarray(ends) - 1
    return sent_ids, cluster_ids, heads, starts, ends
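
# Call sketch (illustrative): xp is an array namespace (numpy or cupy), which
# can be obtained from the current thinc backend:
#
#     from thinc.api import get_current_ops
#     ops = get_current_ops()
#     sent_ids, cluster_ids, heads, starts, ends = doc2tensors(ops.xp, doc)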