diff --git a/spacy/ml/models/coref_util_wl.py b/spacy/ml/models/coref_util_wl.py index 55326c217..20a5f40c4 100644 --- a/spacy/ml/models/coref_util_wl.py +++ b/spacy/ml/models/coref_util_wl.py @@ -4,6 +4,7 @@ from typing import List, Set, Dict, Tuple from thinc.types import Ints1d from dataclasses import dataclass from ...tokens import Doc +from ...language import Language import torch @@ -37,37 +38,7 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False): output = torch.cat((dummy, tensor), dim=1) return output -def make_head_only_clusters(examples): - """Replace coref clusters with head-only clusters. - This destructively modifies the docs. - """ - - #TODO what if all clusters are eliminated? - for eg in examples: - final = [] # save out clusters here - for key, sg in eg.reference.spans.items(): - if not key.startswith("coref_clusters_"): - continue - - heads = [span.root.i for span in sg] - heads = list(set(heads)) - head_spans = [eg.reference[hh:hh+1] for hh in heads] - if len(heads) > 1: - final.append(head_spans) - - # now delete the existing clusters - keys = list(eg.reference.spans.keys()) - for key in keys: - if not key.startswith("coref_clusters_"): - continue - - del eg.reference.spans[key] - - # now add the new spangroups - for ii, spans in enumerate(final): - #TODO support alternate keys - eg.reference.spans[f"coref_clusters_{ii}"] = spans # TODO replace with spaCy config @dataclass diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 0c42ac94a..db93051d7 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -25,8 +25,6 @@ from ..ml.models.coref_util import ( doc2clusters, ) -from ..ml.models.coref_util_wl import make_head_only_clusters - from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe # TODO remove this - kept for reference for now @@ -93,6 +91,31 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"] DEFAULT_CLUSTERS_PREFIX = "coref_clusters" +@Language.component("span2head") +def make_head_only_clusters(doc, old_key="coref_clusters", new_key="coref_head_clusters"): + """Create coref head clusters from span clusters. + + The old clusters are left alone, and the new clusters are added under a different key. + """ + final = [] + for key, sg in doc.spans.items(): + if not key.startswith("{old_key}_"): + continue + + heads = [span.root.i for span in sg] + heads = sorted(list(set(heads))) + head_spans = [doc[hh:hh+1] for hh in heads] + #print("===== headifying =====") + #print(sg) + #print(head_spans) + # singletons are skipped + if len(heads) > 1: + final.append(head_spans) + + # now add the new spangroups + for ii, spans in enumerate(final): + doc.spans[f"{new_key}_{ii}"] = spans + return doc @Language.factory( "coref", @@ -237,8 +260,6 @@ class CoreferenceResolver(TrainablePipe): return losses set_dropout_rate(self.model, drop) - make_head_only_clusters(examples) - inputs = [example.predicted for example in examples] preds, backprop = self.model.begin_update(inputs) score_matrix, mention_idx = preds @@ -399,7 +420,6 @@ class CoreferenceResolver(TrainablePipe): def score(self, examples, **kwargs): """Score a batch of examples.""" - make_head_only_clusters(examples) # NOTE traditionally coref uses the average of b_cubed, muc, and ceaf. # we need to handle the average ourselves. scores = []