mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Make span2head component
This commit is contained in:
parent
e6917d8dc4
commit
0522a43116
|
@ -4,6 +4,7 @@ from typing import List, Set, Dict, Tuple
|
||||||
from thinc.types import Ints1d
|
from thinc.types import Ints1d
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
|
from ...language import Language
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -37,37 +38,7 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
|
||||||
output = torch.cat((dummy, tensor), dim=1)
|
output = torch.cat((dummy, tensor), dim=1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def make_head_only_clusters(examples, prefix: str = "coref_clusters"):
    """Replace coref clusters with head-only clusters.

    For each example, every span group on ``eg.reference`` whose key starts
    with ``f"{prefix}_"`` is reduced to the distinct root tokens of its spans
    (``span.root.i``). Clusters left with a single head are dropped, the
    original groups are deleted, and the surviving head clusters are stored
    again under ``f"{prefix}_{n}"``, renumbered from 0.

    This destructively modifies the reference docs and returns None.

    examples: Iterable of objects exposing a ``reference`` doc.
    prefix: Key prefix identifying the cluster span groups.
    """
    # TODO what if all clusters are eliminated?
    marker = f"{prefix}_"
    for eg in examples:
        ref = eg.reference
        # Snapshot the matching keys so the mapping can be edited afterwards.
        cluster_keys = [key for key in list(ref.spans.keys()) if key.startswith(marker)]

        head_clusters = []
        for key in cluster_keys:
            # Sort so the head spans come out in document order instead of
            # whatever order the set happens to iterate in.
            heads = sorted({span.root.i for span in ref.spans[key]})
            # Singletons are not clusters, so they are skipped.
            if len(heads) > 1:
                head_clusters.append([ref[hh : hh + 1] for hh in heads])

        # Delete the existing clusters before writing the replacements, since
        # the new groups reuse the same key namespace.
        for key in cluster_keys:
            del ref.spans[key]

        for ii, spans in enumerate(head_clusters):
            ref.spans[f"{marker}{ii}"] = spans
|
|
||||||
|
|
||||||
# TODO replace with spaCy config
|
# TODO replace with spaCy config
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -25,8 +25,6 @@ from ..ml.models.coref_util import (
|
||||||
doc2clusters,
|
doc2clusters,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..ml.models.coref_util_wl import make_head_only_clusters
|
|
||||||
|
|
||||||
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
|
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
|
||||||
|
|
||||||
# TODO remove this - kept for reference for now
|
# TODO remove this - kept for reference for now
|
||||||
|
@ -93,6 +91,31 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
|
||||||
|
|
||||||
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||||
|
|
||||||
|
@Language.component("span2head")
|
||||||
|
def make_head_only_clusters(doc, old_key: str = "coref_clusters", new_key: str = "coref_head_clusters"):
    """Create coref head clusters from span clusters.

    Every span group in ``doc.spans`` whose key starts with ``f"{old_key}_"``
    is reduced to the distinct root tokens of its spans (``span.root.i``),
    in document order, each represented as a one-token span. Groups that
    collapse to a single head are skipped.

    The old clusters are left alone, and the new clusters are added under a
    different key: ``f"{new_key}_{n}"``, numbered from 0.

    doc: The doc whose ``spans`` hold the clusters; it is modified in place.
    old_key: Key prefix of the existing span clusters.
    new_key: Key prefix under which the head clusters are stored.
    RETURNS: The same doc.
    """
    # This must be an f-string: the plain literal "{old_key}_" never matches
    # any key, which would leave the doc without head clusters.
    source_prefix = f"{old_key}_"

    # Collect first and write afterwards, so doc.spans is never modified
    # while it is being iterated.
    head_clusters = []
    for key, sg in doc.spans.items():
        if not key.startswith(source_prefix):
            continue

        heads = sorted({span.root.i for span in sg})
        # singletons are skipped
        if len(heads) > 1:
            head_clusters.append([doc[hh : hh + 1] for hh in heads])

    # now add the new spangroups
    for ii, spans in enumerate(head_clusters):
        doc.spans[f"{new_key}_{ii}"] = spans
    return doc
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"coref",
|
"coref",
|
||||||
|
@ -237,8 +260,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
make_head_only_clusters(examples)
|
|
||||||
|
|
||||||
inputs = [example.predicted for example in examples]
|
inputs = [example.predicted for example in examples]
|
||||||
preds, backprop = self.model.begin_update(inputs)
|
preds, backprop = self.model.begin_update(inputs)
|
||||||
score_matrix, mention_idx = preds
|
score_matrix, mention_idx = preds
|
||||||
|
@ -399,7 +420,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
"""Score a batch of examples."""
|
"""Score a batch of examples."""
|
||||||
|
|
||||||
make_head_only_clusters(examples)
|
|
||||||
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
|
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
|
||||||
# we need to handle the average ourselves.
|
# we need to handle the average ourselves.
|
||||||
scores = []
|
scores = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user