mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 12:12:20 +03:00)
Black formatting

commit 838f50192b
parent 2a8efda689
@@ -46,11 +46,7 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
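Note: the hunk above shows only the reflowed signature; the function body is outside the diff context. For orientation, a minimal sketch of the usual Thinc convert_inputs pattern such a converter follows, assuming Thinc's xp2torch/torch2xp helpers — the body here is illustrative, not the committed code:

from typing import List

from thinc.api import ArgsKwargs, Model
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch


def convert_inputs_sketch(model: Model, X: List[Floats2d], is_train: bool):
    # Only the first doc's array is used (no real batching yet, per the TODO).
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Map the torch gradient back to an ops array for the wrapping model.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop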
@@ -115,12 +111,7 @@ class CorefClusterer(torch.nn.Module):
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
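Note on pair_emb = dim * 3 + self.pw.shape: in the word-level coreference architecture this module follows, a candidate pair is commonly represented as the concatenation of the two mention vectors, their elementwise product, and the pairwise distance features. A hypothetical shape check (names and sizes are illustrative):

import torch

dim, pw_shape, n_pairs = 1024, 20, 5
anaphor = torch.randn(n_pairs, dim)
antecedent = torch.randn(n_pairs, dim)
pw = torch.randn(n_pairs, pw_shape)  # output size of DistancePairwiseEncoder
pair = torch.cat((anaphor, antecedent, anaphor * antecedent, pw), dim=1)
assert pair.shape[1] == dim * 3 + pw_shape  # = pair_emb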
@@ -129,13 +120,9 @@ class CorefClusterer(torch.nn.Module):
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
 
-    def forward(
-        self, word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
@@ -350,13 +337,9 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size
 
-    def forward(
-        self,
-        top_indices: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-                    ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
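Note: the forward pass above buckets token distances before an embedding lookup. Distances 1-4 each keep their own bucket (0-3); larger distances share log2-sized buckets, capped at bucket 8 for anything 64 or over. A standalone check of that mapping, lifted directly from the lines above:

import torch

d = torch.tensor([1, 2, 3, 4, 5, 8, 16, 32, 64, 500])
log_d = d.to(torch.float).log2().floor_().clamp_max_(max=6).to(torch.long)
buckets = torch.where(d < 5, d - 1, log_d + 2)
print(buckets.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 8]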
@@ -7,6 +7,7 @@ MentionClusters = List[List[Tuple[int, int]]]
 
 DEFAULT_CLUSTER_PREFIX = "coref_clusters"
 
+
 class GraphNode:
     def __init__(self, node_id: int):
         self.id = node_id

@@ -30,6 +31,7 @@ def get_sentence_ids(doc):
         out.append(sent_id)
     return out
 
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
 

@@ -100,7 +102,6 @@ def get_predicted_clusters(
     return predicted_clusters
 
 
-
 def select_non_crossing_spans(
     idxs: List[int], starts: List[int], ends: List[int], limit: int
 ) -> List[int]:

@@ -150,12 +151,14 @@ def select_non_crossing_spans(
     # selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+
 def create_head_span_idxs(ops, doclen: int):
     """Helper function to create single-token span indices."""
     aa = ops.xp.arange(0, doclen)
     bb = ops.xp.arange(0, doclen) + 1
     return ops.asarray2i([aa, bb]).T
 
+
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
     out = []
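Note: create_head_span_idxs pairs each token index i with i + 1, yielding one single-token (head) span per token. A quick illustration, assuming numpy-backed ops (ops.xp is numpy on CPU; ops.asarray2i just casts to an int array):

import numpy

doclen = 4
aa = numpy.arange(0, doclen)
bb = numpy.arange(0, doclen) + 1
print(numpy.asarray([aa, bb]).T)
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]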
@@ -18,7 +18,7 @@ def build_span_predictor(
     conv_channels: int = 4,
     window_size: int = 1,
     max_distance: int = 128,
-    prefix: str = "coref_head_clusters"
+    prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
     # TODO fix this

@@ -36,7 +36,7 @@ def build_span_predictor(
             distance_embedding_size,
             conv_channels,
             window_size,
-            max_distance
+            max_distance,
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
@@ -139,14 +139,11 @@ class SpanPredictor(torch.nn.Module):
         dist_emb_size: int,
         conv_channels: int,
         window_size: int,
-        max_distance: int
-
+        max_distance: int,
     ):
         super().__init__()
         if max_distance % 2 != 0:
-            raise ValueError(
-                "max_distance has to be an even number"
-            )
+            raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -164,7 +161,7 @@ class SpanPredictor(torch.nn.Module):
         kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
-            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
         # TODO make embeddings size a parameter
         self.max_distance = max_distance
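Note: the positional Conv1d arguments above are (in_channels, out_channels, kernel_size, stride, padding). Padding is hard-coded to 1, so sequence length is preserved only for the default window_size=1 (kernel size 3). A quick shape check under that assumption, taking 64 for dist_emb_size from the "64 = probably distance emb size" comment:

import torch

dist_emb_size, conv_channels, window_size = 64, 4, 1
kernel_size = window_size * 2 + 1
conv = torch.nn.Sequential(
    torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
    torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)
x = torch.randn(1, dist_emb_size, 10)  # (batch, channels, positions)
print(conv(x).shape)  # torch.Size([1, 2, 10])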