Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 20:22:25 +03:00)
Black formatting

Commit 838f50192b (parent 2a8efda689)
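A pass like this is normally produced by running Black over the repository and then verified in CI with the check flag. A minimal sketch of the workflow (the target path and Black version are assumptions, not recorded in the commit):

    pip install black
    # Rewrite files in place with Black's default 88-character line length
    black spacy/
    # CI-style verification: exit non-zero if anything would be reformatted
    black --check --diff spacy/

Everything in the hunks below follows from that one pass: signatures that fit on one line are collapsed, trailing commas are added to exploded argument lists, and comment spacing and blank lines are normalized.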
@@ -46,11 +46,7 @@ def build_wl_coref_model(
     return coref_model


-def convert_coref_clusterer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
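The collapse above is standard Black behavior: because none of the bracketed argument lists ends with a trailing comma, a signature is joined onto a single line whenever it fits within the line length (88 characters by default). A small illustrative example of the same rewrite (hypothetical function, not from this diff):

    # Before Black
    def f(
        x: int,
        y: int
    ) -> int:
        return x + y

    # After Black: fits in 88 columns, no magic trailing comma, so it collapses
    def f(x: int, y: int) -> int:
        return x + y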
@@ -62,7 +58,7 @@ def convert_coref_clusterer_inputs(
         gradients = torch2xp(args.args[0])
         return [gradients]

-    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


 def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
@@ -115,12 +111,7 @@ class CorefClusterer(torch.nn.Module):
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -129,13 +120,9 @@ class CorefClusterer(torch.nn.Module):
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)

-    def forward(
-        self, word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
@@ -350,13 +337,9 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size

-    def forward(
-        self,
-        top_indices: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-                    ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
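Aside from the joined lines, the forward pass above is unchanged: it log-bucketizes token distances so that distances 1 to 4 keep their own buckets (0 to 3), while larger distances share buckets via floor(log2(distance)) + 2, capped at bucket 8. A standalone sketch of that arithmetic (plain PyTorch, mirroring the lines in the hunk):

    import torch

    distance = torch.tensor([1, 2, 3, 4, 5, 8, 16, 64, 300]).clamp_min_(min=1)
    log_distance = distance.to(torch.float).log2().floor_()
    log_distance = log_distance.clamp_max_(max=6).to(torch.long)
    buckets = torch.where(distance < 5, distance - 1, log_distance + 2)
    print(buckets)  # tensor([0, 1, 2, 3, 4, 5, 6, 8, 8])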
@@ -7,6 +7,7 @@ MentionClusters = List[List[Tuple[int, int]]]

 DEFAULT_CLUSTER_PREFIX = "coref_clusters"

+
 class GraphNode:
     def __init__(self, node_id: int):
         self.id = node_id
@@ -30,6 +31,7 @@ def get_sentence_ids(doc):
         out.append(sent_id)
     return out

+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.

@@ -100,7 +102,6 @@ def get_predicted_clusters(
     return predicted_clusters


-
 def select_non_crossing_spans(
     idxs: List[int], starts: List[int], ends: List[int], limit: int
 ) -> List[int]:
@@ -150,12 +151,14 @@ def select_non_crossing_spans(
     # selected.append(selected[0]) # this seems a bit weird?
     return selected

+
 def create_head_span_idxs(ops, doclen: int):
     """Helper function to create single-token span indices."""
     aa = ops.xp.arange(0, doclen)
     bb = ops.xp.arange(0, doclen) + 1
     return ops.asarray2i([aa, bb]).T

+
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
     out = []
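Note that create_head_span_idxs, reformatted above, simply pairs each token index i with i + 1, producing one single-token span per token. A quick usage sketch (assuming a standard thinc NumpyOps backend for ops):

    from thinc.api import NumpyOps

    ops = NumpyOps()
    create_head_span_idxs(ops, 3)
    # array([[0, 1],
    #        [1, 2],
    #        [2, 3]], dtype=int32)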
@@ -163,10 +166,10 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
         cluster = []
         for span in val:
             # TODO check that there isn't an off-by-one error here
-            #cluster.append((span.start, span.end))
+            # cluster.append((span.start, span.end))
             # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append( (head_i, head_i + 1) )
+            cluster.append((head_i, head_i + 1))

         # don't want duplicates
         cluster = list(set(cluster))
@@ -18,7 +18,7 @@ def build_span_predictor(
     conv_channels: int = 4,
     window_size: int = 1,
     max_distance: int = 128,
-    prefix: str = "coref_head_clusters"
+    prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
     # TODO fix this
@@ -36,7 +36,7 @@ def build_span_predictor(
             distance_embedding_size,
             conv_channels,
             window_size,
-            max_distance
+            max_distance,
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
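The trailing commas Black adds in these two hunks matter for future runs: Black treats a trailing comma inside brackets as a "magic trailing comma" and keeps that argument list exploded one element per line, rather than collapsing it the way the signatures earlier in the diff were collapsed. Schematically (hypothetical call, not from this diff):

    # No trailing comma and fits on one line: Black collapses it
    result = build(a, b, c)

    # Magic trailing comma: Black preserves the multi-line layout
    result = build(
        a,
        b,
        c,
    )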
@@ -133,20 +133,17 @@ def head_data_forward(model, docs, is_train):
 # TODO this should maybe have a different name from the component
 class SpanPredictor(torch.nn.Module):
     def __init__(
         self,
         input_size: int,
         hidden_size: int,
         dist_emb_size: int,
         conv_channels: int,
         window_size: int,
-        max_distance: int
-
+        max_distance: int,
     ):
         super().__init__()
         if max_distance % 2 != 0:
-            raise ValueError(
-                "max_distance has to be an even number"
-            )
+            raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -164,7 +161,7 @@ class SpanPredictor(torch.nn.Module):
         kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
-            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
         # TODO make embeddings size a parameter
         self.max_distance = max_distance