mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-24 12:41:23 +03:00
Make span2head component
This commit is contained in:
parent
e6917d8dc4
commit
0522a43116
|
@ -4,6 +4,7 @@ from typing import List, Set, Dict, Tuple
|
|||
from thinc.types import Ints1d
|
||||
from dataclasses import dataclass
|
||||
from ...tokens import Doc
|
||||
from ...language import Language
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -37,37 +38,7 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
|
|||
output = torch.cat((dummy, tensor), dim=1)
|
||||
return output
|
||||
|
||||
def make_head_only_clusters(examples):
|
||||
"""Replace coref clusters with head-only clusters.
|
||||
|
||||
This destructively modifies the docs.
|
||||
"""
|
||||
|
||||
#TODO what if all clusters are eliminated?
|
||||
for eg in examples:
|
||||
final = [] # save out clusters here
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if not key.startswith("coref_clusters_"):
|
||||
continue
|
||||
|
||||
heads = [span.root.i for span in sg]
|
||||
heads = list(set(heads))
|
||||
head_spans = [eg.reference[hh:hh+1] for hh in heads]
|
||||
if len(heads) > 1:
|
||||
final.append(head_spans)
|
||||
|
||||
# now delete the existing clusters
|
||||
keys = list(eg.reference.spans.keys())
|
||||
for key in keys:
|
||||
if not key.startswith("coref_clusters_"):
|
||||
continue
|
||||
|
||||
del eg.reference.spans[key]
|
||||
|
||||
# now add the new spangroups
|
||||
for ii, spans in enumerate(final):
|
||||
#TODO support alternate keys
|
||||
eg.reference.spans[f"coref_clusters_{ii}"] = spans
|
||||
|
||||
# TODO replace with spaCy config
|
||||
@dataclass
|
||||
|
|
|
@ -25,8 +25,6 @@ from ..ml.models.coref_util import (
|
|||
doc2clusters,
|
||||
)
|
||||
|
||||
from ..ml.models.coref_util_wl import make_head_only_clusters
|
||||
|
||||
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
|
||||
|
||||
# TODO remove this - kept for reference for now
|
||||
|
@ -93,6 +91,31 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
|
|||
|
||||
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||
|
||||
@Language.component("span2head")
|
||||
def make_head_only_clusters(doc, old_key="coref_clusters", new_key="coref_head_clusters"):
|
||||
"""Create coref head clusters from span clusters.
|
||||
|
||||
The old clusters are left alone, and the new clusters are added under a different key.
|
||||
"""
|
||||
final = []
|
||||
for key, sg in doc.spans.items():
|
||||
if not key.startswith("{old_key}_"):
|
||||
continue
|
||||
|
||||
heads = [span.root.i for span in sg]
|
||||
heads = sorted(list(set(heads)))
|
||||
head_spans = [doc[hh:hh+1] for hh in heads]
|
||||
#print("===== headifying =====")
|
||||
#print(sg)
|
||||
#print(head_spans)
|
||||
# singletons are skipped
|
||||
if len(heads) > 1:
|
||||
final.append(head_spans)
|
||||
|
||||
# now add the new spangroups
|
||||
for ii, spans in enumerate(final):
|
||||
doc.spans[f"{new_key}_{ii}"] = spans
|
||||
return doc
|
||||
|
||||
@Language.factory(
|
||||
"coref",
|
||||
|
@ -237,8 +260,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
||||
make_head_only_clusters(examples)
|
||||
|
||||
inputs = [example.predicted for example in examples]
|
||||
preds, backprop = self.model.begin_update(inputs)
|
||||
score_matrix, mention_idx = preds
|
||||
|
@ -399,7 +420,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
def score(self, examples, **kwargs):
|
||||
"""Score a batch of examples."""
|
||||
|
||||
make_head_only_clusters(examples)
|
||||
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
|
||||
# we need to handle the average ourselves.
|
||||
scores = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user