Code review suggestions, cleanup

This commit is contained in:
Paul O'Leary McCann 2022-05-25 19:18:26 +09:00
parent e721c7bed8
commit 2a8efda689
3 changed files with 11 additions and 11 deletions

View File

@ -9,14 +9,14 @@ def get_cluster_info(predicted_clusters, gold_clusters):
return (gold_clusters, predicted_clusters, g2p, p2g) return (gold_clusters, predicted_clusters, g2p, p2g)
def get_markable_assignments(inp_clusters, out_clusters): def get_markable_assignments(in_clusters, out_clusters):
markable_cluster_ids = {} markable_cluster_ids = {}
out_dic = {} out_dic = {}
for cluster_id, cluster in enumerate(out_clusters): for cluster_id, cluster in enumerate(out_clusters):
for m in cluster: for m in cluster:
out_dic[m] = cluster_id out_dic[m] = cluster_id
for cluster in inp_clusters: for cluster in in_clusters:
for im in cluster: for im in cluster:
for om in out_dic: for om in out_dic:
if im == om: if im == om:

View File

@ -29,8 +29,8 @@ def build_wl_coref_model(
dim = 768 dim = 768
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
coref_scorer = PyTorchWrapper( coref_clusterer = PyTorchWrapper(
CorefScorer( CorefClusterer(
dim, dim,
distance_embedding_size, distance_embedding_size,
hidden_size, hidden_size,
@ -39,14 +39,14 @@ def build_wl_coref_model(
antecedent_limit, antecedent_limit,
antecedent_batch_size, antecedent_batch_size,
), ),
convert_inputs=convert_coref_scorer_inputs, convert_inputs=convert_coref_clusterer_inputs,
convert_outputs=convert_coref_scorer_outputs, convert_outputs=convert_coref_clusterer_outputs,
) )
coref_model = tok2vec >> coref_scorer coref_model = tok2vec >> coref_clusterer
return coref_model return coref_model
def convert_coref_scorer_inputs( def convert_coref_clusterer_inputs(
model: Model, model: Model,
X: List[Floats2d], X: List[Floats2d],
is_train: bool is_train: bool
@ -65,7 +65,7 @@ def convert_coref_scorer_inputs(
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
_, outputs = inputs_outputs _, outputs = inputs_outputs
scores, indices = outputs scores, indices = outputs
@ -81,7 +81,7 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
return (scores_xp, indices_xp), convert_for_torch_backward return (scores_xp, indices_xp), convert_for_torch_backward
class CorefScorer(torch.nn.Module): class CorefClusterer(torch.nn.Module):
""" """
Combines all coref modules together to find coreferent token pairs. Combines all coref modules together to find coreferent token pairs.
Submodules (in the order of their usage in the pipeline): Submodules (in the order of their usage in the pipeline):

View File

@ -48,7 +48,7 @@ def build_span_predictor(
def convert_span_predictor_inputs( def convert_span_predictor_inputs(
model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool
): ):
tok2vec, (sent_ids, head_ids) = X tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but for these two it's not relevant # Normally we should use the input is_train, but for these two it's not relevant