Black formatting

Paul O'Leary McCann 2022-05-25 19:20:03 +09:00
parent 2a8efda689
commit 838f50192b
3 changed files with 24 additions and 41 deletions

Changed file 1 of 3

@@ -46,11 +46,7 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
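
For readers unfamiliar with Black: it joins a multi-line signature into a single line whenever the result fits within its default 88-character limit, which is the pattern repeated throughout this commit. A minimal sketch with a hypothetical function (the name and arguments are illustrative, not from this repo):

# Before Black
def score_candidates(
    scores: list,
    is_train: bool
):
    ...


# After Black: the joined signature fits within 88 characters, so it becomes one line
def score_candidates(scores: list, is_train: bool):
    ...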
@@ -115,12 +111,7 @@ class CorefClusterer(torch.nn.Module):
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
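
The pair_emb width handed to AnaphoricityScorer is dim * 3 + self.pw.shape, which matches the usual wl-coref pair representation: the two mention vectors, their elementwise product, and the pairwise distance embedding concatenated together. A hedged sketch of that concatenation (the function and variable names are illustrative, not this repo's API):

import torch

def build_pair_matrix(a: torch.Tensor, b: torch.Tensor, pw: torch.Tensor) -> torch.Tensor:
    # a, b: (n_pairs, dim) mention vectors; pw: (n_pairs, emb_size) distance features
    # Result width: dim * 3 + emb_size, i.e. the pair_emb value above.
    return torch.cat([a, b, a * b, pw], dim=1)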
@@ -129,13 +120,9 @@ class CorefClusterer(torch.nn.Module):
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
 
-    def forward(
-        self, word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
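
As a rough illustration of the two numbered steps in that docstring, here is a hedged sketch with made-up dimensions; the real RoughScorer (defined elsewhere in this file) does a bilinear scoring pass and keeps the top roughk candidates per word, and the stand-in below only mimics the shape of that step:

import torch

lstm = torch.nn.LSTM(input_size=64, hidden_size=64, batch_first=True)
word_features = torch.randn(1, 10, 64)                  # (batch, n_words, dim)
encoded, _ = lstm(word_features)                        # 1. LSTM encodes the word features
pair_scores = encoded @ encoded.transpose(1, 2)         # 2. score candidate pairs (crude stand-in)
top_scores, top_indices = pair_scores.topk(5, dim=-1)   #    ... and prune to the top candidates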
@@ -350,13 +337,9 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size
 
-    def forward(
-        self,
-        top_indices: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-        ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
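
The joined expression is the distance bucketing used by the pairwise encoder: distances below 5 each get their own bucket, while larger distances fall into log2 buckets capped at 6 and shifted by 2. A small sketch of the same arithmetic on a handful of raw distances:

import torch

raw = torch.tensor([1, 2, 3, 4, 5, 8, 64, 300])
distance = raw.clamp_min(1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max(6).to(torch.long)
bucket = torch.where(distance < 5, distance - 1, log_distance + 2)
print(bucket)  # tensor([0, 1, 2, 3, 4, 5, 8, 8]) -- 9 possible buckets in total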

Changed file 2 of 3

@@ -7,6 +7,7 @@ MentionClusters = List[List[Tuple[int, int]]]
 
 DEFAULT_CLUSTER_PREFIX = "coref_clusters"
 
+
 class GraphNode:
     def __init__(self, node_id: int):
         self.id = node_id
@@ -30,6 +31,7 @@ def get_sentence_ids(doc):
         out.append(sent_id)
     return out
 
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
@@ -100,7 +102,6 @@ def get_predicted_clusters(
     return predicted_clusters
 
 
-
 def select_non_crossing_spans(
     idxs: List[int], starts: List[int], ends: List[int], limit: int
 ) -> List[int]:
@@ -150,12 +151,14 @@ def select_non_crossing_spans(
     # selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+
 def create_head_span_idxs(ops, doclen: int):
     """Helper function to create single-token span indices."""
     aa = ops.xp.arange(0, doclen)
     bb = ops.xp.arange(0, doclen) + 1
     return ops.asarray2i([aa, bb]).T
 
+
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
     out = []
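
create_head_span_idxs, shown above unchanged, builds one (start, end) pair per token with an exclusive end. A quick numpy equivalent (assuming ops.xp is numpy on CPU; asarray2i is thinc's integer-array conversion):

import numpy

doclen = 4
aa = numpy.arange(0, doclen)
bb = numpy.arange(0, doclen) + 1
print(numpy.asarray([aa, bb]).T)
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]  -- one single-token (start, end) span per token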

Changed file 3 of 3

@@ -18,7 +18,7 @@ def build_span_predictor(
     conv_channels: int = 4,
     window_size: int = 1,
     max_distance: int = 128,
-    prefix: str = "coref_head_clusters"
+    prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
     # TODO fix this
@@ -36,7 +36,7 @@ def build_span_predictor(
             distance_embedding_size,
             conv_channels,
             window_size,
-            max_distance
+            max_distance,
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
@@ -139,14 +139,11 @@ class SpanPredictor(torch.nn.Module):
         dist_emb_size: int,
         conv_channels: int,
         window_size: int,
-        max_distance: int
+        max_distance: int,
     ):
         super().__init__()
         if max_distance % 2 != 0:
-            raise ValueError(
-                "max_distance has to be an even number"
-            )
+            raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -164,7 +161,7 @@ class SpanPredictor(torch.nn.Module):
         kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
-            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
         # TODO make embeddings size a parameter
         self.max_distance = max_distance
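
With the default window_size of 1, kernel_size works out to 3, and the hard-coded padding of 1 then keeps the sequence length unchanged through both convolutions. A quick shape check under assumed sizes (dist_emb_size=64, conv_channels=4, batch of one, ten positions; these numbers are illustrative, not taken from a config):

import torch

conv = torch.nn.Sequential(
    torch.nn.Conv1d(64, 4, 3, 1, 1),  # (batch, 64, seq_len) -> (batch, 4, seq_len)
    torch.nn.Conv1d(4, 2, 3, 1, 1),   # (batch, 4, seq_len) -> (batch, 2, seq_len)
)
x = torch.randn(1, 64, 10)
print(conv(x).shape)  # torch.Size([1, 2, 10])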