Make span2head component

This commit is contained in:
Paul O'Leary McCann 2022-03-15 19:19:15 +09:00
parent e6917d8dc4
commit 0522a43116
2 changed files with 26 additions and 35 deletions

View File

@ -4,6 +4,7 @@ from typing import List, Set, Dict, Tuple
from thinc.types import Ints1d from thinc.types import Ints1d
from dataclasses import dataclass from dataclasses import dataclass
from ...tokens import Doc from ...tokens import Doc
from ...language import Language
import torch import torch
@ -37,37 +38,7 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
output = torch.cat((dummy, tensor), dim=1) output = torch.cat((dummy, tensor), dim=1)
return output return output
def make_head_only_clusters(examples):
"""Replace coref clusters with head-only clusters.
This destructively modifies the docs.
"""
#TODO what if all clusters are eliminated?
for eg in examples:
final = [] # save out clusters here
for key, sg in eg.reference.spans.items():
if not key.startswith("coref_clusters_"):
continue
heads = [span.root.i for span in sg]
heads = list(set(heads))
head_spans = [eg.reference[hh:hh+1] for hh in heads]
if len(heads) > 1:
final.append(head_spans)
# now delete the existing clusters
keys = list(eg.reference.spans.keys())
for key in keys:
if not key.startswith("coref_clusters_"):
continue
del eg.reference.spans[key]
# now add the new spangroups
for ii, spans in enumerate(final):
#TODO support alternate keys
eg.reference.spans[f"coref_clusters_{ii}"] = spans
# TODO replace with spaCy config # TODO replace with spaCy config
@dataclass @dataclass

View File

@ -25,8 +25,6 @@ from ..ml.models.coref_util import (
doc2clusters, doc2clusters,
) )
from ..ml.models.coref_util_wl import make_head_only_clusters
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
# TODO remove this - kept for reference for now # TODO remove this - kept for reference for now
@ -93,6 +91,31 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
DEFAULT_CLUSTERS_PREFIX = "coref_clusters" DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
@Language.component("span2head")
def make_head_only_clusters(doc, old_key="coref_clusters", new_key="coref_head_clusters"):
"""Create coref head clusters from span clusters.
The old clusters are left alone, and the new clusters are added under a different key.
"""
final = []
for key, sg in doc.spans.items():
if not key.startswith("{old_key}_"):
continue
heads = [span.root.i for span in sg]
heads = sorted(list(set(heads)))
head_spans = [doc[hh:hh+1] for hh in heads]
#print("===== headifying =====")
#print(sg)
#print(head_spans)
# singletons are skipped
if len(heads) > 1:
final.append(head_spans)
# now add the new spangroups
for ii, spans in enumerate(final):
doc.spans[f"{new_key}_{ii}"] = spans
return doc
@Language.factory( @Language.factory(
"coref", "coref",
@ -237,8 +260,6 @@ class CoreferenceResolver(TrainablePipe):
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
make_head_only_clusters(examples)
inputs = [example.predicted for example in examples] inputs = [example.predicted for example in examples]
preds, backprop = self.model.begin_update(inputs) preds, backprop = self.model.begin_update(inputs)
score_matrix, mention_idx = preds score_matrix, mention_idx = preds
@ -399,7 +420,6 @@ class CoreferenceResolver(TrainablePipe):
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples.""" """Score a batch of examples."""
make_head_only_clusters(examples)
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf. # NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
# we need to handle the average ourselves. # we need to handle the average ourselves.
scores = [] scores = []