mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-16 11:12:25 +03:00
Refactor Coval Scoring code (#10875)
* Move coref scoring code to scorer.py

  Includes some renames to make names less generic.

* Refactor coval code to remove ternary expressions

* Black formatting

* Add header

* Make scorers into registered scorers

* Small test fixes

* Skip coref tests when torch not present

  Coref can't be loaded without Torch, so nothing works.

* Fix remaining type issues

  Some of this just involves ignoring types in thorny areas. Two main issues:

  1. Some things have weird types due to indirection/ArgsKwargs
  2. The xp2torch return type seems to have changed at some point

* Update spacy/scorer.py

  Co-authored-by: kadarakos <kadar.akos@gmail.com>

* Small changes from review

* Be specific about the ValueError

* Type fix

Co-authored-by: kadarakos <kadar.akos@gmail.com>
This commit is contained in:
parent 196886bbca
commit 16894e665d
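One practical effect of the "registered scorers" change: the coref and span-predictor scorers are now resolvable by name through spaCy's registry instead of being hard-coded on the components. A minimal lookup sketch, assuming this branch is installed:

    from spacy.util import registry

    # "spacy.coref_scorer.v1" is registered in spacy/pipeline/coref.py below
    make_scorer = registry.scorers.get("spacy.coref_scorer.v1")
    scorer = make_scorer()  # a callable over Iterable[Example] returning a score dict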
licenses/3rd_party_licenses.txt

@@ -127,3 +127,36 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+
+coval
+-----
+
+* Files: scorer.py
+
+The implementations of ClusterEvaluator, lea, get_cluster_info, and
+get_markable_assignments are adapted from coval, which is distributed
+under the following license:
+
+The MIT License (MIT)
+
+Copyright 2018 Nafise Sadat Moosavi (ns.moosavi at gmail dot com)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
spacy/coref_scorer.py (deleted)

@@ -1,124 +0,0 @@
-# copied from coval
-# https://github.com/ns-moosavi/coval
-
-
-def get_cluster_info(predicted_clusters, gold_clusters):
-    p2g = get_markable_assignments(predicted_clusters, gold_clusters)
-    g2p = get_markable_assignments(gold_clusters, predicted_clusters)
-    # this is the data format used as input by the evaluator
-    return (gold_clusters, predicted_clusters, g2p, p2g)
-
-
-def get_markable_assignments(in_clusters, out_clusters):
-    markable_cluster_ids = {}
-    out_dic = {}
-    for cluster_id, cluster in enumerate(out_clusters):
-        for m in cluster:
-            out_dic[m] = cluster_id
-
-    for cluster in in_clusters:
-        for im in cluster:
-            for om in out_dic:
-                if im == om:
-                    markable_cluster_ids[im] = out_dic[om]
-                    break
-
-    return markable_cluster_ids
-
-
-def f1(p_num, p_den, r_num, r_den, beta=1):
-    p = 0 if p_den == 0 else p_num / float(p_den)
-    r = 0 if r_den == 0 else r_num / float(r_den)
-    return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
-
-
-class Evaluator:
-    def __init__(self, metric, beta=1, keep_aggregated_values=False):
-        self.p_num = 0
-        self.p_den = 0
-        self.r_num = 0
-        self.r_den = 0
-        self.metric = metric
-        self.beta = beta
-        self.keep_aggregated_values = keep_aggregated_values
-
-        if keep_aggregated_values:
-            self.aggregated_p_num = []
-            self.aggregated_p_den = []
-            self.aggregated_r_num = []
-            self.aggregated_r_den = []
-
-    def update(self, coref_info):
-        (
-            key_clusters,
-            sys_clusters,
-            key_mention_sys_cluster,
-            sys_mention_key_cluster,
-        ) = coref_info
-
-        pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster)
-        rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster)
-        self.p_num += pn
-        self.p_den += pd
-        self.r_num += rn
-        self.r_den += rd
-
-        if self.keep_aggregated_values:
-            self.aggregated_p_num.append(pn)
-            self.aggregated_p_den.append(pd)
-            self.aggregated_r_num.append(rn)
-            self.aggregated_r_den.append(rd)
-
-    def get_f1(self):
-        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
-
-    def get_recall(self):
-        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
-
-    def get_precision(self):
-        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
-
-    def get_prf(self):
-        return self.get_precision(), self.get_recall(), self.get_f1()
-
-    def get_counts(self):
-        return self.p_num, self.p_den, self.r_num, self.r_den
-
-    def get_aggregated_values(self):
-        return (
-            self.aggregated_p_num,
-            self.aggregated_p_den,
-            self.aggregated_r_num,
-            self.aggregated_r_den,
-        )
-
-
-def lea(input_clusters, output_clusters, mention_to_gold):
-    num, den = 0, 0
-
-    for c in input_clusters:
-        if len(c) == 1:
-            all_links = 1
-            if (
-                c[0] in mention_to_gold
-                and len(output_clusters[mention_to_gold[c[0]]]) == 1
-            ):
-                common_links = 1
-            else:
-                common_links = 0
-        else:
-            common_links = 0
-            all_links = len(c) * (len(c) - 1) / 2.0
-            for i, m in enumerate(c):
-                if m in mention_to_gold:
-                    for m2 in c[i + 1 :]:
-                        if (
-                            m2 in mention_to_gold
-                            and mention_to_gold[m] == mention_to_gold[m2]
-                        ):
-                            common_links += 1
-
-        num += len(c) * common_links / float(all_links)
-        den += len(c)
-
-    return num, den
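As a quick numeric check of the f1 helper in the deleted file (the refactored ClusterEvaluator.f1 later in this diff computes the same quantity without ternaries):

    # precision 1/2 = 0.5, recall 2/2 = 1.0
    print(f1(1, 2, 2, 2))          # 2 * 0.5 * 1.0 / (0.5 + 1.0) = 0.666...
    print(f1(1, 2, 2, 2, beta=2))  # (1 + 4) * 0.5 * 1.0 / (4 * 0.5 + 1.0) = 0.833...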
@@ -49,10 +49,10 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)

-    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+    # TODO fix or remove type annotations
+    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
+        # assert isinstance(gradients, Floats2d)
         return [gradients]

     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
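The backprop callback here is the torch-to-xp leg of thinc's PyTorch bridge; the annotation is commented out because, per the commit message, the xp2torch return type seems to have changed at some point. A minimal round-trip sketch, assuming thinc and torch are installed:

    import numpy
    from thinc.api import xp2torch, torch2xp

    x = numpy.zeros((4, 8), dtype="f")
    t = xp2torch(x, requires_grad=True)  # converts to a torch.Tensor
    x2 = torch2xp(t)                     # back to an xp (numpy/cupy) array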
@@ -32,23 +32,6 @@ def get_sentence_ids(doc):
     return out


-def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
-    """Given a doc, give the mention clusters.
-
-    This is useful for scoring.
-    """
-    out = []
-    for name, val in doc.spans.items():
-        if not name.startswith(prefix):
-            continue
-
-        cluster = []
-        for mention in val:
-            cluster.append((mention.start, mention.end))
-        out.append(cluster)
-    return out
-
-
 # from model.py, refactored to be non-member
 def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
     """Get the ID of the antecedent for each span. -1 if no antecedent."""
@@ -43,23 +43,24 @@ def build_span_predictor(


 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool
+    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant

-    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+    # TODO fix the type here, or remove it
+    def backprop(args: ArgsKwargs):  # -> Tuple[List[Floats2d], None]:
         gradients = torch2xp(args.args[1])
+        # The sent_ids and head_ids are None because no gradients
         return [[gradients], None]

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
-    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
+    sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
     if not head_ids[0].size:
-        head_ids = torch.empty(size=(0,))
+        head_ids_tensor = torch.empty(size=(0,))
     else:
-        head_ids = xp2torch(head_ids[0], requires_grad=False)
+        head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)

-    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
+    argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
     return argskwargs, backprop
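Two things happen in this hunk: the *_tensor renames stop the function from rebinding its own arguments (one of the "weird types" issues from the commit message), and the .size guard covers docs with no head mentions, where a zero-length array is falsy. A small sketch of the guard with illustrative values:

    import numpy
    import torch

    head_ids = [numpy.zeros((0,), dtype="i")]   # no head mentions in this doc
    if not head_ids[0].size:
        head_ids_tensor = torch.empty(size=(0,))
    print(head_ids_tensor.shape)  # torch.Size([0])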
@@ -13,6 +13,7 @@ from ..training import Example, validate_examples, validate_get_examples
 from ..errors import Errors
 from ..tokens import Doc
 from ..vocab import Vocab
+from ..util import registry

 from ..ml.models.coref_util import (
     create_gold_scores,
@@ -21,10 +22,9 @@ from ..ml.models.coref_util import (
     get_clusters_from_doc,
     get_predicted_clusters,
     DEFAULT_CLUSTER_PREFIX,
-    doc2clusters,
 )

-from ..coref_scorer import Evaluator, get_cluster_info, lea
+from ..scorer import Scorer

 default_config = """
@@ -57,7 +57,14 @@ depth = 2
 """
 DEFAULT_COREF_MODEL = Config().from_str(default_config)["model"]

-DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
+
+def coref_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
+    return Scorer.score_coref_clusters(examples, **kwargs)
+
+
+@registry.scorers("spacy.coref_scorer.v1")
+def make_coref_scorer():
+    return coref_scorer
+

 @Language.factory(
@@ -67,6 +74,7 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
     default_config={
         "model": DEFAULT_COREF_MODEL,
         "span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
+        "scorer": {"@scorers": "spacy.coref_scorer.v1"},
     },
     default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
 )
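With the scorer exposed as a registry reference in the factory defaults, a user config can swap it without code changes. A hedged sketch, parsed the same way as the default_config strings in this file (the registered name assumes this branch):

    from textwrap import dedent
    from thinc.api import Config

    cfg = Config().from_str(dedent("""
        [components.coref]
        factory = "coref"

        [components.coref.scorer]
        @scorers = "spacy.coref_scorer.v1"
        """))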
@@ -74,12 +82,13 @@ def make_coref(
     nlp: Language,
     name: str,
     model,
-    span_cluster_prefix: str = "coref",
+    scorer: Optional[Callable],
+    span_cluster_prefix: str,
 ) -> "CoreferenceResolver":
     """Create a CoreferenceResolver component."""

     return CoreferenceResolver(
-        nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix
+        nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix, scorer=scorer
     )
@@ -96,7 +105,8 @@ class CoreferenceResolver(TrainablePipe):
         name: str = "coref",
         *,
         span_mentions: str = "coref_mentions",
-        span_cluster_prefix: str,
+        span_cluster_prefix: str = DEFAULT_CLUSTER_PREFIX,
+        scorer: Optional[Callable] = coref_scorer,
     ) -> None:
         """Initialize a coreference resolution component.
@@ -118,7 +128,8 @@ class CoreferenceResolver(TrainablePipe):
         self.span_cluster_prefix = span_cluster_prefix
         self._rehearsal_model = None

-        self.cfg: Dict[str, Any] = {}
+        self.cfg: Dict[str, Any] = {"span_cluster_prefix": span_cluster_prefix}
+        self.scorer = scorer

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
@@ -276,7 +287,6 @@ class CoreferenceResolver(TrainablePipe):
         log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
         log_norm = ops.softmax(score_matrix, axis=1)
         grad = log_norm - log_marg
-        # gradients.append((grad, cidx))
         loss = float((grad**2).sum())

         return loss, grad
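For context on the unchanged loss lines here: despite the log_ prefixes, log_marg and log_norm are probability distributions, and grad is the model's antecedent distribution minus the same distribution renormalized over gold-consistent antecedents, which matches the gradient of the negative marginal log-likelihood with respect to the scores. A self-contained numpy sketch with made-up numbers:

    import numpy as np

    def softmax(x, axis=1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    score_matrix = np.array([[2.0, 1.0, 0.5]])   # scores for 3 candidate antecedents
    top_gscores = np.array([[1.0, 0.0, 1.0]])    # 1.0 where the candidate is gold
    with np.errstate(divide="ignore"):
        log_marg = softmax(score_matrix + np.log(top_gscores), axis=1)
    log_norm = softmax(score_matrix, axis=1)
    grad = log_norm - log_marg                   # pushes mass toward gold antecedents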
@@ -306,26 +316,3 @@ class CoreferenceResolver(TrainablePipe):

         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
-
-    def score(self, examples, **kwargs):
-        """Score a batch of examples using LEA.
-
-        For details on how LEA works and why to use it see the paper:
-        Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
-        Moosavi and Strube, 2016
-        https://api.semanticscholar.org/CorpusID:17606580
-        """
-
-        evaluator = Evaluator(lea)
-
-        for ex in examples:
-            p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
-            g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
-            cluster_info = get_cluster_info(p_clusters, g_clusters)
-            evaluator.update(cluster_info)
-
-        score = {
-            "coref_f": evaluator.get_f1(),
-            "coref_p": evaluator.get_precision(),
-            "coref_r": evaluator.get_recall(),
-        }
-        return score
@@ -5,20 +5,19 @@ from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean

 from .trainable_pipe import TrainablePipe
 from ..language import Language
 from ..training import Example, validate_examples, validate_get_examples
 from ..errors import Errors
-from ..scorer import Scorer
+from ..scorer import Scorer, doc2clusters
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..util import registry

 from ..ml.models.coref_util import (
     MentionClusters,
     DEFAULT_CLUSTER_PREFIX,
-    doc2clusters,
 )

 default_span_predictor_config = """
@@ -52,6 +51,15 @@ depth = 2
 DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]


+def span_predictor_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
+    return Scorer.score_span_predictions(examples, **kwargs)
+
+
+@registry.scorers("spacy.span_predictor_scorer.v1")
+def make_span_predictor_scorer():
+    return span_predictor_scorer
+
+
 @Language.factory(
     "span_predictor",
     assigns=["doc.spans"],
@@ -60,6 +68,7 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
         "model": DEFAULT_SPAN_PREDICTOR_MODEL,
         "input_prefix": "coref_head_clusters",
         "output_prefix": "coref_clusters",
+        "scorer": {"@scorers": "spacy.span_predictor_scorer.v1"},
     },
     default_score_weights={"span_accuracy": 1.0},
 )
@@ -69,10 +78,16 @@ def make_span_predictor(
     model,
     input_prefix: str = "coref_head_clusters",
     output_prefix: str = "coref_clusters",
+    scorer: Optional[Callable] = span_predictor_scorer,
 ) -> "SpanPredictor":
     """Create a SpanPredictor component."""
     return SpanPredictor(
-        nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix
+        nlp.vocab,
+        model,
+        name,
+        input_prefix=input_prefix,
+        output_prefix=output_prefix,
+        scorer=scorer,
     )
@@ -90,6 +105,7 @@ class SpanPredictor(TrainablePipe):
         *,
         input_prefix: str = "coref_head_clusters",
         output_prefix: str = "coref_clusters",
+        scorer: Optional[Callable] = span_predictor_scorer,
     ) -> None:
         self.vocab = vocab
         self.model = model
@@ -97,7 +113,10 @@ class SpanPredictor(TrainablePipe):
         self.input_prefix = input_prefix
         self.output_prefix = output_prefix

-        self.cfg: Dict[str, Any] = {}
+        self.scorer = scorer
+        self.cfg: Dict[str, Any] = {
+            "output_prefix": output_prefix,
+        }

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         # for now pretend there's just one doc
@@ -255,35 +274,3 @@ class SpanPredictor(TrainablePipe):

         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
-
-    def score(self, examples, **kwargs):
-        """
-        Evaluate on reconstructing the correct spans around
-        gold heads.
-        """
-        scores = []
-        xp = self.model.ops.xp
-        for eg in examples:
-            starts = []
-            ends = []
-            pred_starts = []
-            pred_ends = []
-            ref = eg.reference
-            pred = eg.predicted
-            for key, gold_sg in ref.spans.items():
-                if key.startswith(self.output_prefix):
-                    pred_sg = pred.spans[key]
-                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
-                        starts.append(gold_mention.start)
-                        ends.append(gold_mention.end)
-                        pred_starts.append(pred_mention.start)
-                        pred_ends.append(pred_mention.end)
-
-            starts = xp.asarray(starts)
-            ends = xp.asarray(ends)
-            pred_starts = xp.asarray(pred_starts)
-            pred_ends = xp.asarray(pred_ends)
-            correct = (starts == pred_starts) * (ends == pred_ends)
-            accuracy = correct.mean()
-            scores.append(float(accuracy))
-        return {"span_accuracy": mean(scores)}
spacy/scorer.py
@@ -2,6 +2,7 @@ from typing import Optional, Iterable, Dict, Set, List, Any, Callable, Tuple
 from typing import TYPE_CHECKING
 import numpy as np
 from collections import defaultdict
+from statistics import mean

 from .training import Example
 from .tokens import Token, Doc, Span
@@ -9,6 +10,7 @@ from .errors import Errors
 from .util import get_lang_class, SimpleFrozenList
 from .morphology import Morphology


 if TYPE_CHECKING:
     # This lets us add type hints for mypy etc. without causing circular imports
     from .language import Language  # noqa: F401
@@ -873,6 +875,66 @@ class Scorer:
             f"{attr}_las_per_type": None,
         }

+    @staticmethod
+    def score_coref_clusters(examples: Iterable[Example], **cfg):
+        """Score a batch of examples using LEA.
+
+        For details on how LEA works and why to use it see the paper:
+        Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
+        Moosavi and Strube, 2016
+        https://api.semanticscholar.org/CorpusID:17606580
+        """
+
+        span_cluster_prefix = cfg["span_cluster_prefix"]
+
+        evaluator = ClusterEvaluator(lea)
+
+        for ex in examples:
+            p_clusters = doc2clusters(ex.predicted, span_cluster_prefix)
+            g_clusters = doc2clusters(ex.reference, span_cluster_prefix)
+            cluster_info = get_cluster_info(p_clusters, g_clusters)
+            evaluator.update(cluster_info)
+
+        score = {
+            "coref_f": evaluator.get_f1(),
+            "coref_p": evaluator.get_precision(),
+            "coref_r": evaluator.get_recall(),
+        }
+        return score
+
+    @staticmethod
+    def score_span_predictions(examples: Iterable[Example], **cfg):
+        """Evaluate reconstruction of the correct spans from gold heads."""
+        scores = []
+        output_prefix = cfg["output_prefix"]
+        for eg in examples:
+            starts = []
+            ends = []
+            pred_starts = []
+            pred_ends = []
+            ref = eg.reference
+            pred = eg.predicted
+            for key, gold_sg in ref.spans.items():
+                if key.startswith(output_prefix):
+                    pred_sg = pred.spans[key]
+                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
+                        starts.append(gold_mention.start)
+                        ends.append(gold_mention.end)
+                        pred_starts.append(pred_mention.start)
+                        pred_ends.append(pred_mention.end)
+
+            # see how many are perfect
+            cs = [a == b for a, b in zip(starts, pred_starts)]
+            ce = [a == b for a, b in zip(ends, pred_ends)]
+            correct = [int(a and b) for a, b in zip(cs, ce)]
+            accuracy = sum(correct) / len(correct)
+
+            scores.append(float(accuracy))
+        out_key = f"span_{output_prefix}_accuracy"
+        return {out_key: mean(scores)}
+

 def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
     """Compute micro-PRF and per-entity PRF scores for a sequence of examples."""
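The accuracy in score_span_predictions is simply the fraction of mentions whose start and end both match the gold span. A worked example with made-up token offsets:

    starts, ends = [0, 2], [1, 3]            # gold token offsets
    pred_starts, pred_ends = [0, 2], [1, 4]  # second end is off by one
    cs = [a == b for a, b in zip(starts, pred_starts)]  # [True, True]
    ce = [a == b for a, b in zip(ends, pred_ends)]      # [True, False]
    correct = [int(a and b) for a, b in zip(cs, ce)]    # [1, 0]
    accuracy = sum(correct) / len(correct)              # 0.5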
@@ -1143,3 +1205,161 @@ def _auc(x, y):
     # regular numpy.ndarray instances.
     area = area.dtype.type(area)
     return area
+
+
+# The following implementations of get_cluster_info(), get_markable_assignments,
+# and ClusterEvaluator are adapted from coval, which is distributed under the
+# MIT License.
+# Copyright 2018 Nafise Sadat Moosavi
+# See licenses/3rd_party_licenses.txt
+def get_cluster_info(predicted_clusters, gold_clusters):
+    p2g = get_markable_assignments(predicted_clusters, gold_clusters)
+    g2p = get_markable_assignments(gold_clusters, predicted_clusters)
+    # this is the data format used as input by the evaluator
+    return (gold_clusters, predicted_clusters, g2p, p2g)
+
+
+def get_markable_assignments(in_clusters, out_clusters):
+    markable_cluster_ids = {}
+    out_dic = {}
+    for cluster_id, cluster in enumerate(out_clusters):
+        for m in cluster:
+            out_dic[m] = cluster_id
+
+    for cluster in in_clusters:
+        for im in cluster:
+            for om in out_dic:
+                if im == om:
+                    markable_cluster_ids[im] = out_dic[om]
+                    break
+
+    return markable_cluster_ids
+
+
+class ClusterEvaluator:
+    def __init__(self, metric, beta=1, keep_aggregated_values=False):
+        self.p_num = 0
+        self.p_den = 0
+        self.r_num = 0
+        self.r_den = 0
+        self.metric = metric
+        self.beta = beta
+        self.keep_aggregated_values = keep_aggregated_values
+
+        if keep_aggregated_values:
+            self.aggregated_p_num = []
+            self.aggregated_p_den = []
+            self.aggregated_r_num = []
+            self.aggregated_r_den = []
+
+    def update(self, coref_info):
+        (
+            key_clusters,
+            sys_clusters,
+            key_mention_sys_cluster,
+            sys_mention_key_cluster,
+        ) = coref_info
+
+        pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster)
+        rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster)
+        self.p_num += pn
+        self.p_den += pd
+        self.r_num += rn
+        self.r_den += rd
+
+        if self.keep_aggregated_values:
+            self.aggregated_p_num.append(pn)
+            self.aggregated_p_den.append(pd)
+            self.aggregated_r_num.append(rn)
+            self.aggregated_r_den.append(rd)
+
+    def f1(self, p_num, p_den, r_num, r_den, beta=1):
+        p = 0
+        if p_den != 0:
+            p = p_num / float(p_den)
+        r = 0
+        if r_den != 0:
+            r = r_num / float(r_den)
+
+        if p + r == 0:
+            return 0
+
+        return (1 + beta * beta) * p * r / (beta * beta * p + r)
+
+    def get_f1(self):
+        return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
+
+    def get_recall(self):
+        if self.r_num == 0:
+            return 0
+
+        return self.r_num / float(self.r_den)
+
+    def get_precision(self):
+        if self.p_num == 0:
+            return 0
+
+        return self.p_num / float(self.p_den)
+
+    def get_prf(self):
+        return self.get_precision(), self.get_recall(), self.get_f1()
+
+    def get_counts(self):
+        return self.p_num, self.p_den, self.r_num, self.r_den
+
+    def get_aggregated_values(self):
+        return (
+            self.aggregated_p_num,
+            self.aggregated_p_den,
+            self.aggregated_r_num,
+            self.aggregated_r_den,
+        )
+
+
+def lea(input_clusters, output_clusters, mention_to_gold):
+    num, den = 0, 0
+
+    for c in input_clusters:
+        if len(c) == 1:
+            all_links = 1
+            if (
+                c[0] in mention_to_gold
+                and len(output_clusters[mention_to_gold[c[0]]]) == 1
+            ):
+                common_links = 1
+            else:
+                common_links = 0
+        else:
+            common_links = 0
+            all_links = len(c) * (len(c) - 1) / 2.0
+            for i, m in enumerate(c):
+                if m in mention_to_gold:
+                    for m2 in c[i + 1 :]:
+                        if (
+                            m2 in mention_to_gold
+                            and mention_to_gold[m] == mention_to_gold[m2]
+                        ):
+                            common_links += 1
+
+        num += len(c) * common_links / float(all_links)
+        den += len(c)
+
+    return num, den
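Before the hunk continues with doc2clusters below, a small worked example of the adapted coval pipeline, using hypothetical (start, end) mention tuples; a prediction that exactly matches the single gold cluster scores perfectly:

    from spacy.scorer import ClusterEvaluator, get_cluster_info, lea  # this branch

    gold = [[(0, 1), (2, 3)]]           # one gold cluster with two mentions
    pred = [[(0, 1), (2, 3)]]           # predicted clusters match exactly
    info = get_cluster_info(pred, gold)
    evaluator = ClusterEvaluator(lea)
    evaluator.update(info)
    print(evaluator.get_prf())          # (1.0, 1.0, 1.0)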
+
+
+# This is coref related, but not from coval.
+def doc2clusters(doc: Doc, prefix: str) -> List[List[Tuple[int, int]]]:
+    """Given a doc, give the mention clusters.
+
+    This is used for scoring.
+    """
+    out = []
+    for name, val in doc.spans.items():
+        if not name.startswith(prefix):
+            continue
+
+        cluster = []
+        for mention in val:
+            cluster.append((mention.start, mention.end))
+        out.append(cluster)
+    return out
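A usage sketch for doc2clusters, assuming this branch is installed; the offsets in each cluster are token indices taken from Span.start and Span.end:

    from spacy.lang.en import English
    from spacy.tokens import Span
    from spacy.scorer import doc2clusters  # available on this branch

    nlp = English()
    doc = nlp("Sarah said she was tired.")
    # store one mention cluster under the prefix the scorer expects
    doc.spans["coref_clusters_1"] = [Span(doc, 0, 1), Span(doc, 2, 3)]  # Sarah, she
    print(doc2clusters(doc, prefix="coref_clusters"))  # [[(0, 1), (2, 3)]]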
@@ -5,24 +5,26 @@ from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
 from spacy.tests.util import make_tempdir
-from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
 from spacy.ml.models.coref_util import (
+    DEFAULT_CLUSTER_PREFIX,
     select_non_crossing_spans,
     get_sentence_ids,
 )

+from thinc.util import has_torch
+
 # fmt: off
 TRAIN_DATA = [
     (
         "Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
         {
             "spans": {
-                f"{DEFAULT_CLUSTERS_PREFIX}_1": [
+                f"{DEFAULT_CLUSTER_PREFIX}_1": [
                     (5, 6, "MENTION"),      # I
                     (40, 42, "MENTION"),    # me

                 ],
-                f"{DEFAULT_CLUSTERS_PREFIX}_2": [
+                f"{DEFAULT_CLUSTER_PREFIX}_2": [
                     (52, 54, "MENTION"),    # it
                     (95, 103, "MENTION"),   # this SMS
                 ]
@@ -45,18 +47,20 @@ def snlp():
     return en


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_add_pipe(nlp):
     nlp.add_pipe("coref")
     assert nlp.pipe_names == ["coref"]


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_not_initialized(nlp):
     nlp.add_pipe("coref")
     text = "She gave me her pen."
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="E109"):
         nlp(text)


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
     nlp.add_pipe("coref")
     nlp.initialize()
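pytest applies match via re.search over the rendered exception message, so the test now pins spaCy's E109 error code instead of accepting any ValueError (the "Be specific about the ValueError" item in the commit message). A standalone sketch; the message text here is illustrative, not spaCy's exact wording:

    import pytest

    with pytest.raises(ValueError, match="E109"):
        # hypothetical message in the style of spaCy's coded errors
        raise ValueError("[E109] Component 'coref' is not initialized.")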
@@ -68,15 +72,16 @@ def test_initialized(nlp):
     assert len(v) <= 15


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized_short(nlp):
     nlp.add_pipe("coref")
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "Hi there"
     doc = nlp(text)
     print(doc.spans)


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_coref_serialization(nlp):
     # Test that the coref component can be serialized
     nlp.add_pipe("coref", last=True)
@@ -101,6 +106,7 @@ def test_coref_serialization(nlp):
     # assert spans_result == spans_result2


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
     # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
     train_examples = []
@@ -147,6 +153,7 @@ def test_overfitting_IO(nlp):
     # assert_equal(batch_deps_1, no_batch_deps)


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():
     starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
     ends = [12, 12, 2, 3, 3, 4, 4, 4, 3, 4, 5]
@@ -158,6 +165,7 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess

+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")
     sm = get_sentence_ids(doc)
@@ -7,8 +7,9 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
 import numpy
 from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
 from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
+from spacy.ml.models import build_spancat_model
 if has_torch:
-    from spacy.ml.models import build_spancat_model, build_wl_coref_model
+    from spacy.ml.models import build_wl_coref_model, build_span_predictor
 from spacy.ml.staticvectors import StaticVectors
 from spacy.ml.extract_spans import extract_spans, _get_span_indices
 from spacy.lang.en import English