Make span2head component

This commit is contained in:
Paul O'Leary McCann 2022-03-15 19:19:15 +09:00
parent e6917d8dc4
commit 0522a43116
2 changed files with 26 additions and 35 deletions

View File

@ -4,6 +4,7 @@ from typing import List, Set, Dict, Tuple
from thinc.types import Ints1d
from dataclasses import dataclass
from ...tokens import Doc
from ...language import Language
import torch
@ -37,37 +38,7 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
output = torch.cat((dummy, tensor), dim=1)
return output
def make_head_only_clusters(examples):
"""Replace coref clusters with head-only clusters.
This destructively modifies the docs.
"""
#TODO what if all clusters are eliminated?
for eg in examples:
final = [] # save out clusters here
for key, sg in eg.reference.spans.items():
if not key.startswith("coref_clusters_"):
continue
heads = [span.root.i for span in sg]
heads = list(set(heads))
head_spans = [eg.reference[hh:hh+1] for hh in heads]
if len(heads) > 1:
final.append(head_spans)
# now delete the existing clusters
keys = list(eg.reference.spans.keys())
for key in keys:
if not key.startswith("coref_clusters_"):
continue
del eg.reference.spans[key]
# now add the new spangroups
for ii, spans in enumerate(final):
#TODO support alternate keys
eg.reference.spans[f"coref_clusters_{ii}"] = spans
# TODO replace with spaCy config
@dataclass

View File

@ -25,8 +25,6 @@ from ..ml.models.coref_util import (
doc2clusters,
)
from ..ml.models.coref_util_wl import make_head_only_clusters
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
# TODO remove this - kept for reference for now
@ -93,6 +91,31 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
@Language.component("span2head")
def make_head_only_clusters(doc, old_key="coref_clusters", new_key="coref_head_clusters"):
"""Create coref head clusters from span clusters.
The old clusters are left alone, and the new clusters are added under a different key.
"""
final = []
for key, sg in doc.spans.items():
if not key.startswith("{old_key}_"):
continue
heads = [span.root.i for span in sg]
heads = sorted(list(set(heads)))
head_spans = [doc[hh:hh+1] for hh in heads]
#print("===== headifying =====")
#print(sg)
#print(head_spans)
# singletons are skipped
if len(heads) > 1:
final.append(head_spans)
# now add the new spangroups
for ii, spans in enumerate(final):
doc.spans[f"{new_key}_{ii}"] = spans
return doc
@Language.factory(
"coref",
@ -237,8 +260,6 @@ class CoreferenceResolver(TrainablePipe):
return losses
set_dropout_rate(self.model, drop)
make_head_only_clusters(examples)
inputs = [example.predicted for example in examples]
preds, backprop = self.model.begin_update(inputs)
score_matrix, mention_idx = preds
@ -399,7 +420,6 @@ class CoreferenceResolver(TrainablePipe):
def score(self, examples, **kwargs):
"""Score a batch of examples."""
make_head_only_clusters(examples)
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
# we need to handle the average ourselves.
scores = []