diff --git a/spacy/coref_scorer.py b/spacy/coref_scorer.py index b266ec3b3..981b1cf03 100644 --- a/spacy/coref_scorer.py +++ b/spacy/coref_scorer.py @@ -9,14 +9,14 @@ def get_cluster_info(predicted_clusters, gold_clusters): return (gold_clusters, predicted_clusters, g2p, p2g) -def get_markable_assignments(inp_clusters, out_clusters): +def get_markable_assignments(in_clusters, out_clusters): markable_cluster_ids = {} out_dic = {} for cluster_id, cluster in enumerate(out_clusters): for m in cluster: out_dic[m] = cluster_id - for cluster in inp_clusters: + for cluster in in_clusters: for im in cluster: for om in out_dic: if im == om: diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index ca9011577..4be22dd96 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -29,8 +29,8 @@ def build_wl_coref_model( dim = 768 with Model.define_operators({">>": chain}): - coref_scorer = PyTorchWrapper( - CorefScorer( + coref_clusterer = PyTorchWrapper( + CorefClusterer( dim, distance_embedding_size, hidden_size, @@ -39,14 +39,14 @@ def build_wl_coref_model( antecedent_limit, antecedent_batch_size, ), - convert_inputs=convert_coref_scorer_inputs, - convert_outputs=convert_coref_scorer_outputs, + convert_inputs=convert_coref_clusterer_inputs, + convert_outputs=convert_coref_clusterer_outputs, ) - coref_model = tok2vec >> coref_scorer + coref_model = tok2vec >> coref_clusterer return coref_model -def convert_coref_scorer_inputs( +def convert_coref_clusterer_inputs( model: Model, X: List[Floats2d], is_train: bool @@ -65,7 +65,7 @@ def convert_coref_scorer_inputs( return ArgsKwargs(args=(word_features, ), kwargs={}), backprop -def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): +def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool): _, outputs = inputs_outputs scores, indices = outputs @@ -81,7 +81,7 @@ def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool): return (scores_xp, indices_xp), convert_for_torch_backward -class CorefScorer(torch.nn.Module): +class CorefClusterer(torch.nn.Module): """ Combines all coref modules together to find coreferent token pairs. Submodules (in the order of their usage in the pipeline): diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 1ded9c3c7..03101edf9 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -48,7 +48,7 @@ def build_span_predictor( def convert_span_predictor_inputs( - model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool + model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool ): tok2vec, (sent_ids, head_ids) = X # Normally we should use the input is_train, but for these two it's not relevant