Black formatting

This commit is contained in:
Paul O'Leary McCann 2022-05-25 19:20:03 +09:00
parent 2a8efda689
commit 838f50192b
3 changed files with 24 additions and 41 deletions

View File

@ -46,11 +46,7 @@ def build_wl_coref_model(
return coref_model
def convert_coref_clusterer_inputs(
model: Model,
X: List[Floats2d],
is_train: bool
):
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc
# just use the first
# TODO real batching
@ -115,12 +111,7 @@ class CorefClusterer(torch.nn.Module):
# Modules
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb,
hidden_size,
n_layers,
dropout
)
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
self.lstm = torch.nn.LSTM(
input_size=dim,
hidden_size=dim,
@ -129,13 +120,9 @@ class CorefClusterer(torch.nn.Module):
self.rough_scorer = RoughScorer(dim, dropout, roughk)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb, hidden_size, n_layers, dropout
)
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
def forward(
self, word_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
1. LSTM encodes the incoming word_features.
2. The RoughScorer scores and prunes the candidates.
@ -350,13 +337,9 @@ class DistancePairwiseEncoder(torch.nn.Module):
self.dropout = torch.nn.Dropout(dropout)
self.shape = emb_size
def forward(
self,
top_indices: torch.Tensor
) -> torch.Tensor:
def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0))
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
distance = torch.where(distance < 5, distance - 1, log_distance + 2)

View File

@ -7,6 +7,7 @@ MentionClusters = List[List[Tuple[int, int]]]
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
class GraphNode:
def __init__(self, node_id: int):
self.id = node_id
@ -30,6 +31,7 @@ def get_sentence_ids(doc):
out.append(sent_id)
return out
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
"""Given a doc, give the mention clusters.
@ -100,7 +102,6 @@ def get_predicted_clusters(
return predicted_clusters
def select_non_crossing_spans(
idxs: List[int], starts: List[int], ends: List[int], limit: int
) -> List[int]:
@ -150,12 +151,14 @@ def select_non_crossing_spans(
# selected.append(selected[0]) # this seems a bit weird?
return selected
def create_head_span_idxs(ops, doclen: int):
"""Helper function to create single-token span indices."""
aa = ops.xp.arange(0, doclen)
bb = ops.xp.arange(0, doclen) + 1
return ops.asarray2i([aa, bb]).T
def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
"""Given a Doc, convert the cluster spans to simple int tuple lists."""
out = []

View File

@ -18,7 +18,7 @@ def build_span_predictor(
conv_channels: int = 4,
window_size: int = 1,
max_distance: int = 128,
prefix: str = "coref_head_clusters"
prefix: str = "coref_head_clusters",
):
# TODO add model return types
# TODO fix this
@ -36,7 +36,7 @@ def build_span_predictor(
distance_embedding_size,
conv_channels,
window_size,
max_distance
max_distance,
),
convert_inputs=convert_span_predictor_inputs,
)
@ -139,14 +139,11 @@ class SpanPredictor(torch.nn.Module):
dist_emb_size: int,
conv_channels: int,
window_size: int,
max_distance: int
max_distance: int,
):
super().__init__()
if max_distance % 2 != 0:
raise ValueError(
"max_distance has to be an even number"
)
raise ValueError("max_distance has to be an even number")
# input size = single token size
# 64 = probably distance emb size
# TODO check that dist_emb_size use is correct
@ -164,7 +161,7 @@ class SpanPredictor(torch.nn.Module):
kernel_size = window_size * 2 + 1
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)
# TODO make embeddings size a parameter
self.max_distance = max_distance