mirror of https://github.com/explosion/spaCy.git
Code review suggestions, cleanup
This commit is contained in:
parent e721c7bed8
commit 2a8efda689
@@ -9,14 +9,14 @@ def get_cluster_info(predicted_clusters, gold_clusters):
     return (gold_clusters, predicted_clusters, g2p, p2g)


-def get_markable_assignments(inp_clusters, out_clusters):
+def get_markable_assignments(in_clusters, out_clusters):
     markable_cluster_ids = {}
     out_dic = {}
     for cluster_id, cluster in enumerate(out_clusters):
         for m in cluster:
             out_dic[m] = cluster_id

-    for cluster in inp_clusters:
+    for cluster in in_clusters:
         for im in cluster:
             for om in out_dic:
                 if im == om:
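This hunk only renames the parameter and stops at the inner if; the rest of the function is not part of the diff. For orientation, here is a self-contained sketch of the markable-to-cluster-id mapping the visible lines build. Everything after the if (and the return value) is an assumption for illustration, not the file's actual continuation.

# Sketch only: repeats the visible lines and adds an assumed tail.
def get_markable_assignments_sketch(in_clusters, out_clusters):
    markable_cluster_ids = {}
    out_dic = {}
    for cluster_id, cluster in enumerate(out_clusters):
        for m in cluster:
            out_dic[m] = cluster_id

    for cluster in in_clusters:
        for im in cluster:
            for om in out_dic:
                if im == om:
                    # Assumed: record which output cluster this input markable landed in.
                    markable_cluster_ids[im] = out_dic[om]
                    break
    return markable_cluster_ids

# Usage with clusters of hashable markables (here: (start, end) offset pairs):
gold = [[(0, 1), (5, 6)], [(2, 4)]]
pred = [[(0, 1), (2, 4)]]
assert get_markable_assignments_sketch(gold, pred) == {(0, 1): 0, (2, 4): 0}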
@@ -29,8 +29,8 @@ def build_wl_coref_model(
     dim = 768

     with Model.define_operators({">>": chain}):
-        coref_scorer = PyTorchWrapper(
-            CorefScorer(
+        coref_clusterer = PyTorchWrapper(
+            CorefClusterer(
                 dim,
                 distance_embedding_size,
                 hidden_size,
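Aside from the rename, the pattern here is Thinc's standard way of embedding a torch module: PyTorchWrapper turns a torch.nn.Module into a Thinc Model, and the >> operator (bound to chain) composes it with the tok2vec layer. A minimal generic sketch of the same wiring with a toy torch layer (all names below are illustrative, not spaCy's):

import torch
from thinc.api import Linear, Model, PyTorchWrapper, chain

# Toy torch module standing in for the wrapped clusterer.
torch_layer = torch.nn.Linear(768, 64)

with Model.define_operators({">>": chain}):
    wrapped = PyTorchWrapper(torch_layer)      # torch.nn.Module -> Thinc Model
    model = Linear(nO=768, nI=300) >> wrapped  # same shape as tok2vec >> coref_clusterer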
@@ -39,14 +39,14 @@ def build_wl_coref_model(
                 antecedent_limit,
                 antecedent_batch_size,
             ),
-            convert_inputs=convert_coref_scorer_inputs,
-            convert_outputs=convert_coref_scorer_outputs,
+            convert_inputs=convert_coref_clusterer_inputs,
+            convert_outputs=convert_coref_clusterer_outputs,
         )
-        coref_model = tok2vec >> coref_scorer
+        coref_model = tok2vec >> coref_clusterer
     return coref_model


-def convert_coref_scorer_inputs(
+def convert_coref_clusterer_inputs(
     model: Model,
     X: List[Floats2d],
     is_train: bool
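The renamed convert_inputs/convert_outputs functions are what PyTorchWrapper calls to move data across the Thinc/torch boundary. A hedged sketch of what an input converter with this signature usually does: the return shape (an ArgsKwargs plus a backprop callback) is visible in the next hunk, while the rest of the body here is an assumption.

from typing import List
from thinc.api import ArgsKwargs, Model, torch2xp, xp2torch
from thinc.types import Floats2d

def convert_inputs_sketch(model: Model, X: List[Floats2d], is_train: bool):
    # Assumed: take one doc's word features and hand them to torch as a tensor.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Route the gradient from the torch side back into the Thinc chain.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop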
@@ -65,7 +65,7 @@ def convert_coref_scorer_inputs(
     return ArgsKwargs(args=(word_features, ), kwargs={}), backprop


-def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
     _, outputs = inputs_outputs
     scores, indices = outputs

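On the output side, the wrapped module hands back a (scores, indices) pair, and the converter copies both to xp arrays while wiring up the torch backward call. A hedged sketch following the usual PyTorchWrapper convention; the lines not shown as diff context above are assumptions.

from thinc.api import ArgsKwargs, Model, torch2xp, xp2torch

def convert_outputs_sketch(model: Model, inputs_outputs, is_train: bool):
    _, outputs = inputs_outputs
    scores, indices = outputs

    def convert_for_torch_backward(dY):
        # Assumed: only the scores carry gradient; indices are integer positions.
        dScores = xp2torch(dY[0])
        return ArgsKwargs(args=([scores],), kwargs={"grad_tensors": [dScores]})

    scores_xp = torch2xp(scores)
    indices_xp = torch2xp(indices)
    return (scores_xp, indices_xp), convert_for_torch_backward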
@@ -81,7 +81,7 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
     return (scores_xp, indices_xp), convert_for_torch_backward


-class CorefScorer(torch.nn.Module):
+class CorefClusterer(torch.nn.Module):
     """
     Combines all coref modules together to find coreferent token pairs.
     Submodules (in the order of their usage in the pipeline):
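The class rename is the torch side of the same cleanup. Purely for context, a toy skeleton of a module whose forward returns the (scores, indices) pair the output converter above expects; the real CorefClusterer has more arguments and submodules than this diff shows.

import torch

class ClustererSketch(torch.nn.Module):
    """Toy stand-in: scores a limited set of candidate antecedents per word."""

    def __init__(self, dim: int, hidden_size: int, antecedent_limit: int):
        super().__init__()
        self.antecedent_limit = antecedent_limit
        self.scorer = torch.nn.Sequential(
            torch.nn.Linear(dim * 2, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, word_features: torch.Tensor):
        n = word_features.size(0)
        k = min(self.antecedent_limit, n)
        # Toy pruning: for each word, keep the k most similar words by dot product.
        sims = word_features @ word_features.T
        indices = sims.topk(k, dim=1).indices                    # (n, k)
        pairs = torch.cat(
            [word_features.unsqueeze(1).expand(-1, k, -1),
             word_features[indices]],
            dim=-1,
        )                                                         # (n, k, 2 * dim)
        scores = self.scorer(pairs).squeeze(-1)                   # (n, k)
        return scores, indices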
@@ -48,7 +48,7 @@ def build_span_predictor(


 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
+    model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
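The annotation change makes the second element of X a nested pair, matching the unpacking tok2vec, (sent_ids, head_ids) = X two lines below. A hedged sketch of a converter built around that unpacking; the element types in the signature and the body are assumptions in the same style as the coref converters above.

from typing import List, Tuple
from thinc.api import ArgsKwargs, Model, torch2xp, xp2torch
from thinc.types import Floats2d, Ints1d

def convert_span_inputs_sketch(
    model: Model, X: Tuple[Floats2d, Tuple[Ints1d, Ints1d]], is_train: bool
):
    tok2vec, (sent_ids, head_ids) = X
    word_features = xp2torch(tok2vec, requires_grad=is_train)
    # Sentence and head ids are indices, so they do not need gradients.
    sent_ids_t = xp2torch(sent_ids, requires_grad=False)
    head_ids_t = xp2torch(head_ids, requires_grad=False)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Only the word features receive a gradient from the torch side.
        return [torch2xp(args.args[1])]

    return ArgsKwargs(args=(sent_ids_t, word_features, head_ids_t), kwargs={}), backprop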