Refactor Coval Scoring code (#10875)

* Move coref scoring code to scorer.py

Includes some renames to make names less generic.

* Refactor coval code to remove ternary expressions

* Black formatting

* Add header

* Make scorers into registered scorers

* Small test fixes

* Skip coref tests when torch not present

Coref can't be loaded without Torch, so none of the coref tests can run.

* Fix remaining type issues

Some of this just involves ignoring types in thorny areas. Two main
issues:

1. Some things have weird types due to indirection / ArgsKwargs
2. The xp2torch return type seems to have changed at some point

* Update spacy/scorer.py

Co-authored-by: kadarakos <kadar.akos@gmail.com>

* Small changes from review

* Be specific about the ValueError

* Type fix

Co-authored-by: kadarakos <kadar.akos@gmail.com>
Paul O'Leary McCann 2022-06-22 16:05:52 +09:00 committed by GitHub
parent 196886bbca
commit 16894e665d
10 changed files with 321 additions and 225 deletions


@@ -127,3 +127,36 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
coval
-----
* Files: scorer.py
The implementations of ClusterEvaluator, lea, get_cluster_info, and
get_markable_assignments are adapted from coval, which is distributed
under the following license:
The MIT License (MIT)
Copyright 2018 Nafise Sadat Moosavi (ns.moosavi at gmail dot com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -1,124 +0,0 @@
# copied from coval
# https://github.com/ns-moosavi/coval
def get_cluster_info(predicted_clusters, gold_clusters):
p2g = get_markable_assignments(predicted_clusters, gold_clusters)
g2p = get_markable_assignments(gold_clusters, predicted_clusters)
# this is the data format used as input by the evaluator
return (gold_clusters, predicted_clusters, g2p, p2g)
def get_markable_assignments(in_clusters, out_clusters):
markable_cluster_ids = {}
out_dic = {}
for cluster_id, cluster in enumerate(out_clusters):
for m in cluster:
out_dic[m] = cluster_id
for cluster in in_clusters:
for im in cluster:
for om in out_dic:
if im == om:
markable_cluster_ids[im] = out_dic[om]
break
return markable_cluster_ids
def f1(p_num, p_den, r_num, r_den, beta=1):
p = 0 if p_den == 0 else p_num / float(p_den)
r = 0 if r_den == 0 else r_num / float(r_den)
return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
class Evaluator:
def __init__(self, metric, beta=1, keep_aggregated_values=False):
self.p_num = 0
self.p_den = 0
self.r_num = 0
self.r_den = 0
self.metric = metric
self.beta = beta
self.keep_aggregated_values = keep_aggregated_values
if keep_aggregated_values:
self.aggregated_p_num = []
self.aggregated_p_den = []
self.aggregated_r_num = []
self.aggregated_r_den = []
def update(self, coref_info):
(
key_clusters,
sys_clusters,
key_mention_sys_cluster,
sys_mention_key_cluster,
) = coref_info
pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster)
rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster)
self.p_num += pn
self.p_den += pd
self.r_num += rn
self.r_den += rd
if self.keep_aggregated_values:
self.aggregated_p_num.append(pn)
self.aggregated_p_den.append(pd)
self.aggregated_r_num.append(rn)
self.aggregated_r_den.append(rd)
def get_f1(self):
return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
def get_recall(self):
return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
def get_precision(self):
return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
def get_prf(self):
return self.get_precision(), self.get_recall(), self.get_f1()
def get_counts(self):
return self.p_num, self.p_den, self.r_num, self.r_den
def get_aggregated_values(self):
return (
self.aggregated_p_num,
self.aggregated_p_den,
self.aggregated_r_num,
self.aggregated_r_den,
)
def lea(input_clusters, output_clusters, mention_to_gold):
num, den = 0, 0
for c in input_clusters:
if len(c) == 1:
all_links = 1
if (
c[0] in mention_to_gold
and len(output_clusters[mention_to_gold[c[0]]]) == 1
):
common_links = 1
else:
common_links = 0
else:
common_links = 0
all_links = len(c) * (len(c) - 1) / 2.0
for i, m in enumerate(c):
if m in mention_to_gold:
for m2 in c[i + 1 :]:
if (
m2 in mention_to_gold
and mention_to_gold[m] == mention_to_gold[m2]
):
common_links += 1
num += len(c) * common_links / float(all_links)
den += len(c)
return num, den
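
For reference, a minimal usage sketch of the module being removed here, with toy mention clusters given as (start, end) offsets; the same pattern applies to the ClusterEvaluator version added to spacy/scorer.py below.

gold = [[(0, 1), (4, 5)], [(7, 9)]]
pred = [[(0, 1), (4, 5), (7, 9)]]

# The evaluator consumes the (gold, pred, g2p, p2g) tuple built by
# get_cluster_info and accumulates LEA counts with each update() call.
evaluator = Evaluator(lea)
evaluator.update(get_cluster_info(pred, gold))
precision, recall, f_score = evaluator.get_prf()
# precision ≈ 0.33, recall ≈ 0.67, f_score ≈ 0.44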


@@ -49,10 +49,10 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
X = X[0]
word_features = xp2torch(X, requires_grad=is_train)
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# TODO fix or remove type annotations
def backprop(args: ArgsKwargs): #-> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[0])
# assert isinstance(gradients, Floats2d)
return [gradients]
return ArgsKwargs(args=(word_features,), kwargs={}), backprop
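
A standalone sketch of the xp-to-torch round trip used in this hunk (assumes thinc and torch are installed; the array shape is arbitrary). It only restates the conversion pattern, not the full PyTorchWrapper plumbing:

import numpy
from thinc.api import ArgsKwargs
from thinc.util import xp2torch, torch2xp

# Forward: wrap the xp array as a torch tensor and pack it for the wrapped
# PyTorch module.
X = numpy.zeros((4, 8), dtype="f")
word_features = xp2torch(X, requires_grad=True)
inputs = ArgsKwargs(args=(word_features,), kwargs={})

# Backward: thinc hands back an ArgsKwargs of torch gradients; convert them
# back to xp and wrap in a list to mirror the List[Floats2d] input.
def backprop(args: ArgsKwargs):
    return [torch2xp(args.args[0])]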


@@ -32,23 +32,6 @@ def get_sentence_ids(doc):
return out
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
"""Given a doc, give the mention clusters.
This is useful for scoring.
"""
out = []
for name, val in doc.spans.items():
if not name.startswith(prefix):
continue
cluster = []
for mention in val:
cluster.append((mention.start, mention.end))
out.append(cluster)
return out
# from model.py, refactored to be non-member
def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores):
"""Get the ID of the antecedent for each span. -1 if no antecedent."""


@@ -43,23 +43,24 @@ def build_span_predictor(
def convert_span_predictor_inputs(
model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool
model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but for these two it's not relevant
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# TODO fix the type here, or remove it
def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]:
gradients = torch2xp(args.args[1])
# The sent_ids and head_ids are None because no gradients
return [[gradients], None]
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)
if not head_ids[0].size:
head_ids = torch.empty(size=(0,))
head_ids_tensor = torch.empty(size=(0,))
else:
head_ids = xp2torch(head_ids[0], requires_grad=False)
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
return argskwargs, backprop
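
The zero-size guard above, restated as a standalone sketch with per-doc arrays instead of the batched lists (assumes numpy, torch, and thinc are installed):

import numpy
import torch
from thinc.util import xp2torch

head_ids = numpy.zeros((0,), dtype="i")
if not head_ids.size:
    # No heads for this doc: create an empty torch tensor directly rather
    # than converting the zero-length xp array.
    head_ids_tensor = torch.empty(size=(0,))
else:
    head_ids_tensor = xp2torch(head_ids, requires_grad=False)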


@@ -13,6 +13,7 @@ from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..tokens import Doc
from ..vocab import Vocab
from ..util import registry
from ..ml.models.coref_util import (
create_gold_scores,
@@ -21,10 +22,9 @@ from ..ml.models.coref_util import (
get_clusters_from_doc,
get_predicted_clusters,
DEFAULT_CLUSTER_PREFIX,
doc2clusters,
)
from ..coref_scorer import Evaluator, get_cluster_info, lea
from ..scorer import Scorer
default_config = """
@@ -57,7 +57,14 @@ depth = 2
"""
DEFAULT_COREF_MODEL = Config().from_str(default_config)["model"]
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
def coref_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_coref_clusters(examples, **kwargs)
@registry.scorers("spacy.coref_scorer.v1")
def make_coref_scorer():
return coref_scorer
@Language.factory(
@@ -67,6 +74,7 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
default_config={
"model": DEFAULT_COREF_MODEL,
"span_cluster_prefix": DEFAULT_CLUSTER_PREFIX,
"scorer": {"@scorers": "spacy.coref_scorer.v1"},
},
default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None},
)
@@ -74,12 +82,13 @@ def make_coref(
nlp: Language,
name: str,
model,
span_cluster_prefix: str = "coref",
scorer: Optional[Callable],
span_cluster_prefix: str,
) -> "CoreferenceResolver":
"""Create a CoreferenceResolver component."""
return CoreferenceResolver(
nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix
nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix, scorer=scorer
)
@@ -96,7 +105,8 @@ class CoreferenceResolver(TrainablePipe):
name: str = "coref",
*,
span_mentions: str = "coref_mentions",
span_cluster_prefix: str,
span_cluster_prefix: str = DEFAULT_CLUSTER_PREFIX,
scorer: Optional[Callable] = coref_scorer,
) -> None:
"""Initialize a coreference resolution component.
@@ -118,7 +128,8 @@ class CoreferenceResolver(TrainablePipe):
self.span_cluster_prefix = span_cluster_prefix
self._rehearsal_model = None
self.cfg: Dict[str, Any] = {}
self.cfg: Dict[str, Any] = {"span_cluster_prefix": span_cluster_prefix}
self.scorer = scorer
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
@@ -276,7 +287,6 @@ class CoreferenceResolver(TrainablePipe):
log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
log_norm = ops.softmax(score_matrix, axis=1)
grad = log_norm - log_marg
# gradients.append((grad, cidx))
loss = float((grad**2).sum())
return loss, grad
@@ -306,26 +316,3 @@ class CoreferenceResolver(TrainablePipe):
assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y)
def score(self, examples, **kwargs):
"""Score a batch of examples using LEA.
For details on how LEA works and why to use it see the paper:
Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
Moosavi and Strube, 2016
https://api.semanticscholar.org/CorpusID:17606580
"""
evaluator = Evaluator(lea)
for ex in examples:
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info)
score = {
"coref_f": evaluator.get_f1(),
"coref_p": evaluator.get_precision(),
"coref_r": evaluator.get_recall(),
}
return score
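
Because the scorer is now a registered function rather than a score() method on the pipe, it can be swapped out from a config. A small sketch of a custom replacement; the "my_coref_scorer.v1" name and the extra rounded key are hypothetical:

from spacy.scorer import Scorer
from spacy.util import registry

@registry.scorers("my_coref_scorer.v1")  # hypothetical name
def make_custom_coref_scorer():
    def scorer(examples, **kwargs):
        # Reuse the built-in LEA scoring, then add a derived value.
        scores = Scorer.score_coref_clusters(examples, **kwargs)
        scores["coref_f_rounded"] = round(scores["coref_f"], 2)
        return scores
    return scorer

In a pipeline config this would take the place of the default {"@scorers": "spacy.coref_scorer.v1"} block shown above.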


@@ -5,20 +5,19 @@ from thinc.types import Floats2d, Floats3d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate, to_categorical
from itertools import islice
from statistics import mean
from .trainable_pipe import TrainablePipe
from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..scorer import Scorer, doc2clusters
from ..tokens import Doc
from ..vocab import Vocab
from ..util import registry
from ..ml.models.coref_util import (
MentionClusters,
DEFAULT_CLUSTER_PREFIX,
doc2clusters,
)
default_span_predictor_config = """
@@ -52,6 +51,15 @@ depth = 2
DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
def span_predictor_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_span_predictions(examples, **kwargs)
@registry.scorers("spacy.span_predictor_scorer.v1")
def make_span_predictor_scorer():
return span_predictor_scorer
@Language.factory(
"span_predictor",
assigns=["doc.spans"],
@@ -60,6 +68,7 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)[
"model": DEFAULT_SPAN_PREDICTOR_MODEL,
"input_prefix": "coref_head_clusters",
"output_prefix": "coref_clusters",
"scorer": {"@scorers": "spacy.span_predictor_scorer.v1"},
},
default_score_weights={"span_accuracy": 1.0},
)
@@ -69,10 +78,16 @@ def make_span_predictor(
model,
input_prefix: str = "coref_head_clusters",
output_prefix: str = "coref_clusters",
scorer: Optional[Callable] = span_predictor_scorer,
) -> "SpanPredictor":
"""Create a SpanPredictor component."""
return SpanPredictor(
nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix
nlp.vocab,
model,
name,
input_prefix=input_prefix,
output_prefix=output_prefix,
scorer=scorer,
)
@@ -90,6 +105,7 @@ class SpanPredictor(TrainablePipe):
*,
input_prefix: str = "coref_head_clusters",
output_prefix: str = "coref_clusters",
scorer: Optional[Callable] = span_predictor_scorer,
) -> None:
self.vocab = vocab
self.model = model
@@ -97,7 +113,10 @@ class SpanPredictor(TrainablePipe):
self.input_prefix = input_prefix
self.output_prefix = output_prefix
self.cfg: Dict[str, Any] = {}
self.scorer = scorer
self.cfg: Dict[str, Any] = {
"output_prefix": output_prefix,
}
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
# for now pretend there's just one doc
@@ -255,35 +274,3 @@ class SpanPredictor(TrainablePipe):
assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y)
def score(self, examples, **kwargs):
"""
Evaluate on reconstructing the correct spans around
gold heads.
"""
scores = []
xp = self.model.ops.xp
for eg in examples:
starts = []
ends = []
pred_starts = []
pred_ends = []
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(self.output_prefix):
pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)
pred_starts.append(pred_mention.start)
pred_ends.append(pred_mention.end)
starts = xp.asarray(starts)
ends = xp.asarray(ends)
pred_starts = xp.asarray(pred_starts)
pred_ends = xp.asarray(pred_ends)
correct = (starts == pred_starts) * (ends == pred_ends)
accuracy = correct.mean()
scores.append(float(accuracy))
return {"span_accuracy": mean(scores)}


@@ -2,6 +2,7 @@ from typing import Optional, Iterable, Dict, Set, List, Any, Callable, Tuple
from typing import TYPE_CHECKING
import numpy as np
from collections import defaultdict
from statistics import mean
from .training import Example
from .tokens import Token, Doc, Span
@@ -9,6 +10,7 @@ from .errors import Errors
from .util import get_lang_class, SimpleFrozenList
from .morphology import Morphology
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
@@ -873,6 +875,66 @@ class Scorer:
f"{attr}_las_per_type": None,
}
@staticmethod
def score_coref_clusters(examples: Iterable[Example], **cfg):
"""Score a batch of examples using LEA.
For details on how LEA works and why to use it see the paper:
Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
Moosavi and Strube, 2016
https://api.semanticscholar.org/CorpusID:17606580
"""
span_cluster_prefix = cfg["span_cluster_prefix"]
evaluator = ClusterEvaluator(lea)
for ex in examples:
p_clusters = doc2clusters(ex.predicted, span_cluster_prefix)
g_clusters = doc2clusters(ex.reference, span_cluster_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info)
score = {
"coref_f": evaluator.get_f1(),
"coref_p": evaluator.get_precision(),
"coref_r": evaluator.get_recall(),
}
return score
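
A toy end-to-end call of the new static scorer (a sketch assuming a spaCy build with this change; torch is not needed for the scoring itself since it only reads span groups):

import spacy
from spacy.scorer import Scorer
from spacy.training import Example

nlp = spacy.blank("en")
ref = nlp("John said he would come.")
pred = nlp("John said he would come.")
ref.spans["coref_clusters_1"] = [ref[0:1], ref[2:3]]    # John, he
pred.spans["coref_clusters_1"] = [pred[0:1], pred[2:3]]

scores = Scorer.score_coref_clusters(
    [Example(pred, ref)], span_cluster_prefix="coref_clusters"
)
# {'coref_f': 1.0, 'coref_p': 1.0, 'coref_r': 1.0}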
@staticmethod
def score_span_predictions(examples: Iterable[Example], **cfg):
"""Evaluate reconstruction of the correct spans from gold heads.
"""
scores = []
output_prefix = cfg["output_prefix"]
for eg in examples:
starts = []
ends = []
pred_starts = []
pred_ends = []
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(output_prefix):
pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)
pred_starts.append(pred_mention.start)
pred_ends.append(pred_mention.end)
# see how many are perfect
cs = [a == b for a, b in zip(starts, pred_starts)]
ce = [a == b for a, b in zip(ends, pred_ends)]
correct = [int(a and b) for a, b in zip(cs, ce)]
accuracy = sum(correct) / len(correct)
scores.append(float(accuracy))
out_key = f"span_{output_prefix}_accuracy"
return {out_key: mean(scores)}
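
And a matching sketch for the span-accuracy scorer, where one of two predicted mentions matches the gold span exactly (again assuming a build with this change):

import spacy
from spacy.scorer import Scorer
from spacy.training import Example

nlp = spacy.blank("en")
ref = nlp("Sarah called her sister yesterday evening.")
pred = nlp("Sarah called her sister yesterday evening.")
ref.spans["coref_clusters_1"] = [ref[0:1], ref[2:4]]     # Sarah, her sister
pred.spans["coref_clusters_1"] = [pred[0:1], pred[2:3]]  # Sarah, her

scores = Scorer.score_span_predictions(
    [Example(pred, ref)], output_prefix="coref_clusters"
)
# {'span_coref_clusters_accuracy': 0.5}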
def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Compute micro-PRF and per-entity PRF scores for a sequence of examples."""
@@ -1143,3 +1205,161 @@ def _auc(x, y):
# regular numpy.ndarray instances.
area = area.dtype.type(area)
return area
# The following implementations of get_cluster_info(), get_markable_assignments,
# and ClusterEvaluator are adapted from coval, which is distributed under the
# MIT License.
# Copyright 2018 Nafise Sadat Moosavi
# See licenses/3rd_party_licenses.txt
def get_cluster_info(predicted_clusters, gold_clusters):
p2g = get_markable_assignments(predicted_clusters, gold_clusters)
g2p = get_markable_assignments(gold_clusters, predicted_clusters)
# this is the data format used as input by the evaluator
return (gold_clusters, predicted_clusters, g2p, p2g)
def get_markable_assignments(in_clusters, out_clusters):
markable_cluster_ids = {}
out_dic = {}
for cluster_id, cluster in enumerate(out_clusters):
for m in cluster:
out_dic[m] = cluster_id
for cluster in in_clusters:
for im in cluster:
for om in out_dic:
if im == om:
markable_cluster_ids[im] = out_dic[om]
break
return markable_cluster_ids
class ClusterEvaluator:
def __init__(self, metric, beta=1, keep_aggregated_values=False):
self.p_num = 0
self.p_den = 0
self.r_num = 0
self.r_den = 0
self.metric = metric
self.beta = beta
self.keep_aggregated_values = keep_aggregated_values
if keep_aggregated_values:
self.aggregated_p_num = []
self.aggregated_p_den = []
self.aggregated_r_num = []
self.aggregated_r_den = []
def update(self, coref_info):
(
key_clusters,
sys_clusters,
key_mention_sys_cluster,
sys_mention_key_cluster,
) = coref_info
pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster)
rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster)
self.p_num += pn
self.p_den += pd
self.r_num += rn
self.r_den += rd
if self.keep_aggregated_values:
self.aggregated_p_num.append(pn)
self.aggregated_p_den.append(pd)
self.aggregated_r_num.append(rn)
self.aggregated_r_den.append(rd)
def f1(self, p_num, p_den, r_num, r_den, beta=1):
p = 0
if p_den != 0:
p = p_num / float(p_den)
r = 0
if r_den != 0:
r = r_num / float(r_den)
if p + r == 0:
return 0
return (1 + beta * beta) * p * r / (beta * beta * p + r)
def get_f1(self):
return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
def get_recall(self):
if self.r_num == 0:
return 0
return self.r_num / float(self.r_den)
def get_precision(self):
if self.p_num == 0:
return 0
return self.p_num / float(self.p_den)
def get_prf(self):
return self.get_precision(), self.get_recall(), self.get_f1()
def get_counts(self):
return self.p_num, self.p_den, self.r_num, self.r_den
def get_aggregated_values(self):
return (
self.aggregated_p_num,
self.aggregated_p_den,
self.aggregated_r_num,
self.aggregated_r_den,
)
def lea(input_clusters, output_clusters, mention_to_gold):
num, den = 0, 0
for c in input_clusters:
if len(c) == 1:
all_links = 1
if (
c[0] in mention_to_gold
and len(output_clusters[mention_to_gold[c[0]]]) == 1
):
common_links = 1
else:
common_links = 0
else:
common_links = 0
all_links = len(c) * (len(c) - 1) / 2.0
for i, m in enumerate(c):
if m in mention_to_gold:
for m2 in c[i + 1 :]:
if (
m2 in mention_to_gold
and mention_to_gold[m] == mention_to_gold[m2]
):
common_links += 1
num += len(c) * common_links / float(all_links)
den += len(c)
return num, den
# This is coref related, but not from coval.
def doc2clusters(doc: Doc, prefix: str) -> List[List[Tuple[int, int]]]:
"""Given a doc, give the mention clusters.
This is used for scoring.
"""
out = []
for name, val in doc.spans.items():
if not name.startswith(prefix):
continue
cluster = []
for mention in val:
cluster.append((mention.start, mention.end))
out.append(cluster)
return out
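
A toy illustration of the cluster format doc2clusters produces, with mentions as (start, end) token offsets and one list per matching span group (assumes doc2clusters is importable from spacy.scorer after this change, as the span_predictor import above suggests):

import spacy
from spacy.scorer import doc2clusters

nlp = spacy.blank("en")
doc = nlp("John said he would come.")
doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]  # John, he
print(doc2clusters(doc, prefix="coref_clusters"))
# [[(0, 1), (2, 3)]]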


@@ -5,24 +5,26 @@ from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
from spacy.ml.models.coref_util import (
DEFAULT_CLUSTER_PREFIX,
select_non_crossing_spans,
get_sentence_ids,
)
from thinc.util import has_torch
# fmt: off
TRAIN_DATA = [
(
"Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.",
{
"spans": {
f"{DEFAULT_CLUSTERS_PREFIX}_1": [
f"{DEFAULT_CLUSTER_PREFIX}_1": [
(5, 6, "MENTION"), # I
(40, 42, "MENTION"), # me
],
f"{DEFAULT_CLUSTERS_PREFIX}_2": [
f"{DEFAULT_CLUSTER_PREFIX}_2": [
(52, 54, "MENTION"), # it
(95, 103, "MENTION"), # this SMS
]
@@ -45,18 +47,20 @@ def snlp():
return en
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_add_pipe(nlp):
nlp.add_pipe("coref")
assert nlp.pipe_names == ["coref"]
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_not_initialized(nlp):
nlp.add_pipe("coref")
text = "She gave me her pen."
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="E109"):
nlp(text)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized(nlp):
nlp.add_pipe("coref")
nlp.initialize()
@@ -68,15 +72,16 @@ def test_initialized(nlp):
assert len(v) <= 15
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized_short(nlp):
nlp.add_pipe("coref")
nlp.initialize()
assert nlp.pipe_names == ["coref"]
text = "Hi there"
doc = nlp(text)
print(doc.spans)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_coref_serialization(nlp):
# Test that the coref component can be serialized
nlp.add_pipe("coref", last=True)
@@ -101,6 +106,7 @@ def test_coref_serialization(nlp):
# assert spans_result == spans_result2
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_overfitting_IO(nlp):
# Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
train_examples = []
@@ -147,6 +153,7 @@ def test_overfitting_IO(nlp):
# assert_equal(batch_deps_1, no_batch_deps)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_crossing_spans():
starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
ends = [12, 12, 2, 3, 3, 4, 4, 4, 3, 4, 5]
@@ -158,6 +165,7 @@ def test_crossing_spans():
guess = sorted(guess)
assert gold == guess
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_sentence_map(snlp):
doc = snlp("I like text. This is text.")
sm = get_sentence_ids(doc)


@@ -7,8 +7,9 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
import numpy
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.models import build_spancat_model
if has_torch:
from spacy.ml.models import build_spancat_model, build_wl_coref_model
from spacy.ml.models import build_wl_coref_model, build_span_predictor
from spacy.ml.staticvectors import StaticVectors
from spacy.ml.extract_spans import extract_spans, _get_span_indices
from spacy.lang.en import English