Clean up util code

Moved everything into coref_util.py, deleted wl-specific file.
Paul O'Leary McCann 2022-03-15 19:59:44 +09:00
parent 55039a66ad
commit abdc7d87af
3 changed files with 32 additions and 136 deletions

View File

@@ -458,7 +458,7 @@ import torch
 from thinc.util import xp2torch, torch2xp

 # TODO rename this to coref_util
-from .coref_util_wl import add_dummy
+from .coref_util import add_dummy

 # TODO rename to plain coref
 @registry.architectures("spacy.WLCoref.v1")

View File: coref_util.py

@@ -1,13 +1,43 @@
 from thinc.types import Ints2d
 from spacy.tokens import Doc
-from typing import List, Tuple, Callable, Any
+from typing import List, Tuple, Callable, Any, Set, Dict
 from ...util import registry
+import torch

 # type alias to make writing this less tedious
 MentionClusters = List[List[Tuple[int, int]]]

 DEFAULT_CLUSTER_PREFIX = "coref_clusters"

+EPSILON = 1e-7
+
+
+class GraphNode:
+    def __init__(self, node_id: int):
+        self.id = node_id
+        self.links: Set[GraphNode] = set()
+        self.visited = False
+
+    def link(self, another: "GraphNode"):
+        self.links.add(another)
+        another.links.add(self)
+
+    def __repr__(self) -> str:
+        return str(self.id)
+
+
+def add_dummy(tensor: torch.Tensor, eps: bool = False):
+    """Prepends zeros (or a very small value if eps is True)
+    to the first (not zeroth) dimension of tensor.
+    """
+    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
+    shape: List[int] = list(tensor.shape)
+    shape[1] = 1
+    if not eps:
+        dummy = torch.zeros(shape, **kwargs)  # type: ignore
+    else:
+        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
+    output = torch.cat((dummy, tensor), dim=1)
+    return output
+
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.

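For reference, add_dummy prepends a column along dimension 1, which antecedent scorers in the wl-coref family typically reserve as the "no antecedent" option. A minimal usage sketch, assuming the module lives at spacy.ml.models.coref_util (the import path is not shown in this diff):

    import torch
    from spacy.ml.models.coref_util import add_dummy  # import path assumed

    scores = torch.rand(5, 3)    # 5 mentions, 3 antecedent scores each
    padded = add_dummy(scores)   # a zero column is prepended on dim 1
    assert padded.shape == (5, 4)
    assert torch.all(padded[:, 0] == 0)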
View File: coref_util_wl.py (deleted)

@@ -1,134 +0,0 @@
-""" Contains functions not directly linked to coreference resolution """
-from typing import List, Set, Dict, Tuple
-from thinc.types import Ints1d
-from dataclasses import dataclass
-from ...tokens import Doc
-from ...language import Language
-import torch
-
-EPSILON = 1e-7
-
-
-class GraphNode:
-    def __init__(self, node_id: int):
-        self.id = node_id
-        self.links: Set[GraphNode] = set()
-        self.visited = False
-
-    def link(self, another: "GraphNode"):
-        self.links.add(another)
-        another.links.add(self)
-
-    def __repr__(self) -> str:
-        return str(self.id)
-
-
-def add_dummy(tensor: torch.Tensor, eps: bool = False):
-    """Prepends zeros (or a very small value if eps is True)
-    to the first (not zeroth) dimension of tensor.
-    """
-    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
-    shape: List[int] = list(tensor.shape)
-    shape[1] = 1
-    if not eps:
-        dummy = torch.zeros(shape, **kwargs)  # type: ignore
-    else:
-        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
-    output = torch.cat((dummy, tensor), dim=1)
-    return output
-
-
-# TODO replace with spaCy config
-@dataclass
-class CorefConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
-    """ Contains values needed to set up the coreference model. """
-    section: str
-    data_dir: str
-    train_data: str
-    dev_data: str
-    test_data: str
-    device: str
-    bert_model: str
-    bert_window_size: int
-    embedding_size: int
-    sp_embedding_size: int
-    a_scoring_batch_size: int
-    hidden_size: int
-    n_hidden_layers: int
-    max_span_len: int
-    rough_k: int
-    bert_finetune: bool
-    bert_mini_finetune: bool
-    dropout_rate: float
-    learning_rate: float
-    bert_learning_rate: float
-    train_epochs: int
-    bce_loss_weight: float
-    tokenizer_kwargs: Dict[str, dict]
-    conll_log_dir: str
-
-
-def get_sent_ids(doc):
-    sid = 0
-    sids = []
-    for sent in doc.sents:
-        for tok in sent:
-            sids.append(sid)
-        sid += 1
-    return sids
-
-
-def get_cluster_ids(doc):
-    """Get the cluster ids of head tokens."""
-    out = [0] * len(doc)
-    head_spangroups = [doc.spans[sk] for sk in doc.spans if sk.startswith("coref_word_clusters")]
-    for ii, group in enumerate(head_spangroups, start=1):
-        for span in group:
-            out[span[0].i] = ii
-    return out
-
-
-def get_head2span(doc):
-    out = []
-    for sk in doc.spans:
-        if not sk.startswith("coref_clusters"):
-            continue
-        if len(doc.spans[sk]) == 1:
-            print("===== UNARY MENTION ====")
-        for span in doc.spans[sk]:
-            out.append((span.root.i, span.start, span.end))
-    return out
-
-
-def doc2tensors(
-    xp,
-    doc: Doc
-) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]:
-    sent_ids = get_sent_ids(doc)
-    cluster_ids = get_cluster_ids(doc)
-    head2span = get_head2span(doc)
-    if not head2span:
-        heads, starts, ends = [], [], []
-    else:
-        heads, starts, ends = zip(*head2span)
-
-    sent_ids = xp.asarray(sent_ids)
-    cluster_ids = xp.asarray(cluster_ids)
-    heads = xp.asarray(heads)
-    starts = xp.asarray(starts)
-    ends = xp.asarray(ends) - 1
-    return sent_ids, cluster_ids, heads, starts, ends
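GraphNode itself is unchanged by the move: link stores mention links symmetrically, and the visited flag supports the kind of traversal that reads coreference clusters off as connected components. A sketch of that traversal, assuming the same coref_util import path as above; the cluster_of helper is illustrative and not part of this commit:

    from typing import List
    from spacy.ml.models.coref_util import GraphNode  # import path assumed

    def cluster_of(node: GraphNode) -> List[int]:
        # Illustrative helper: collect the connected component containing
        # node by walking .links and marking .visited along the way.
        stack, out = [node], []
        node.visited = True
        while stack:
            current = stack.pop()
            out.append(current.id)
            for neighbor in current.links:
                if not neighbor.visited:
                    neighbor.visited = True
                    stack.append(neighbor)
        return sorted(out)

    # Linking token 0 to 3 and 3 to 7 produces one three-token cluster.
    a, b, c = GraphNode(0), GraphNode(3), GraphNode(7)
    a.link(b)
    b.link(c)
    assert cluster_of(a) == [0, 3, 7]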