diff --git a/licenses/3rd_party_licenses.txt b/licenses/3rd_party_licenses.txt index d58da9c4a..c605c40b9 100644 --- a/licenses/3rd_party_licenses.txt +++ b/licenses/3rd_party_licenses.txt @@ -127,3 +127,36 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + +coval +----- + +* Files: scorer.py + +The implementations of ClusterEvaluator, lea, get_cluster_info, and +get_markable_assignments are adapted from coval, which is distributed +under the following license: + +The MIT License (MIT) + +Copyright 2018 Nafise Sadat Moosavi (ns.moosavi at gmail dot com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/spacy/coref_scorer.py b/spacy/coref_scorer.py deleted file mode 100644 index 981b1cf03..000000000 --- a/spacy/coref_scorer.py +++ /dev/null @@ -1,124 +0,0 @@ -# copied from coval -# https://github.com/ns-moosavi/coval - - -def get_cluster_info(predicted_clusters, gold_clusters): - p2g = get_markable_assignments(predicted_clusters, gold_clusters) - g2p = get_markable_assignments(gold_clusters, predicted_clusters) - # this is the data format used as input by the evaluator - return (gold_clusters, predicted_clusters, g2p, p2g) - - -def get_markable_assignments(in_clusters, out_clusters): - markable_cluster_ids = {} - out_dic = {} - for cluster_id, cluster in enumerate(out_clusters): - for m in cluster: - out_dic[m] = cluster_id - - for cluster in in_clusters: - for im in cluster: - for om in out_dic: - if im == om: - markable_cluster_ids[im] = out_dic[om] - break - - return markable_cluster_ids - - -def f1(p_num, p_den, r_num, r_den, beta=1): - p = 0 if p_den == 0 else p_num / float(p_den) - r = 0 if r_den == 0 else r_num / float(r_den) - return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) - - -class Evaluator: - def __init__(self, metric, beta=1, keep_aggregated_values=False): - self.p_num = 0 - self.p_den = 0 - self.r_num = 0 - self.r_den = 0 - self.metric = metric - self.beta = beta - self.keep_aggregated_values = keep_aggregated_values - - if keep_aggregated_values: - self.aggregated_p_num = [] - self.aggregated_p_den = [] - self.aggregated_r_num = [] - self.aggregated_r_den = [] - - def update(self, coref_info): - ( - key_clusters, - sys_clusters, - key_mention_sys_cluster, - sys_mention_key_cluster, - ) = coref_info - - pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster) - rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster) - self.p_num += pn - self.p_den += pd - self.r_num += rn - self.r_den += rd - - if self.keep_aggregated_values: - self.aggregated_p_num.append(pn) - self.aggregated_p_den.append(pd) - self.aggregated_r_num.append(rn) - self.aggregated_r_den.append(rd) - - def get_f1(self): - return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) - - def get_recall(self): - return 0 if self.r_num == 0 else self.r_num / float(self.r_den) - - def get_precision(self): - return 0 if self.p_num == 0 else self.p_num / float(self.p_den) - - def get_prf(self): - return self.get_precision(), self.get_recall(), self.get_f1() - - def get_counts(self): - return self.p_num, self.p_den, self.r_num, self.r_den - - def get_aggregated_values(self): - return ( - self.aggregated_p_num, - self.aggregated_p_den, - self.aggregated_r_num, - self.aggregated_r_den, - ) - - -def lea(input_clusters, output_clusters, mention_to_gold): - num, den = 0, 0 - - for c in input_clusters: - if len(c) == 1: - all_links = 1 - if ( - c[0] in mention_to_gold - and len(output_clusters[mention_to_gold[c[0]]]) == 1 - ): - common_links = 1 - else: - common_links = 0 - else: - common_links = 0 - all_links = len(c) * (len(c) - 1) / 2.0 - for i, m in enumerate(c): - if m in mention_to_gold: - for m2 in c[i + 1 :]: - if ( - m2 in mention_to_gold - and mention_to_gold[m] == mention_to_gold[m2] - ): - common_links += 1 - - num += len(c) * common_links / float(all_links) - den += len(c) - - return num, den diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index c7fb2ba24..a8c880a39 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -49,10 +49,10 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo X = X[0] word_features = xp2torch(X, requires_grad=is_train) - def backprop(args: ArgsKwargs) -> List[Floats2d]: + # TODO fix or remove type annotations + def backprop(args: ArgsKwargs): #-> List[Floats2d]: # convert to xp and wrap in list gradients = torch2xp(args.args[0]) - # assert isinstance(gradients, Floats2d) return [gradients] return ArgsKwargs(args=(word_features,), kwargs={}), backprop diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index dc9366a61..a004a69d7 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -32,23 +32,6 @@ def get_sentence_ids(doc): return out -def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters: - """Given a doc, give the mention clusters. - - This is useful for scoring. - """ - out = [] - for name, val in doc.spans.items(): - if not name.startswith(prefix): - continue - - cluster = [] - for mention in val: - cluster.append((mention.start, mention.end)) - out.append(cluster) - return out - - # from model.py, refactored to be non-member def get_predicted_antecedents(xp, antecedent_idx, antecedent_scores): """Get the ID of the antecedent for each span. -1 if no antecedent.""" diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 7962e4157..378b79e9b 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -43,23 +43,24 @@ def build_span_predictor( def convert_span_predictor_inputs( - model: Model, X: Tuple[Ints1d, Tuple[Floats2d, Ints1d]], is_train: bool + model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool ): tok2vec, (sent_ids, head_ids) = X # Normally we should use the input is_train, but for these two it's not relevant - - def backprop(args: ArgsKwargs) -> List[Floats2d]: + # TODO fix the type here, or remove it + def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]: gradients = torch2xp(args.args[1]) + # The sent_ids and head_ids are None because no gradients return [[gradients], None] word_features = xp2torch(tok2vec[0], requires_grad=is_train) - sent_ids = xp2torch(sent_ids[0], requires_grad=False) + sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False) if not head_ids[0].size: - head_ids = torch.empty(size=(0,)) + head_ids_tensor = torch.empty(size=(0,)) else: - head_ids = xp2torch(head_ids[0], requires_grad=False) + head_ids_tensor = xp2torch(head_ids[0], requires_grad=False) - argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) + argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}) return argskwargs, backprop diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 96dc80f53..cd07f80e8 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -13,6 +13,7 @@ from ..training import Example, validate_examples, validate_get_examples from ..errors import Errors from ..tokens import Doc from ..vocab import Vocab +from ..util import registry from ..ml.models.coref_util import ( create_gold_scores, @@ -21,10 +22,9 @@ from ..ml.models.coref_util import ( get_clusters_from_doc, get_predicted_clusters, DEFAULT_CLUSTER_PREFIX, - doc2clusters, ) -from ..coref_scorer import Evaluator, get_cluster_info, lea +from ..scorer import Scorer default_config = """ @@ -57,7 +57,14 @@ depth = 2 """ DEFAULT_COREF_MODEL = Config().from_str(default_config)["model"] -DEFAULT_CLUSTERS_PREFIX = "coref_clusters" + +def coref_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: + return Scorer.score_coref_clusters(examples, **kwargs) + + +@registry.scorers("spacy.coref_scorer.v1") +def make_coref_scorer(): + return coref_scorer @Language.factory( @@ -67,6 +74,7 @@ DEFAULT_CLUSTERS_PREFIX = "coref_clusters" default_config={ "model": DEFAULT_COREF_MODEL, "span_cluster_prefix": DEFAULT_CLUSTER_PREFIX, + "scorer": {"@scorers": "spacy.coref_scorer.v1"}, }, default_score_weights={"coref_f": 1.0, "coref_p": None, "coref_r": None}, ) @@ -74,12 +82,13 @@ def make_coref( nlp: Language, name: str, model, - span_cluster_prefix: str = "coref", + scorer: Optional[Callable], + span_cluster_prefix: str, ) -> "CoreferenceResolver": """Create a CoreferenceResolver component.""" return CoreferenceResolver( - nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix + nlp.vocab, model, name, span_cluster_prefix=span_cluster_prefix, scorer=scorer ) @@ -96,7 +105,8 @@ class CoreferenceResolver(TrainablePipe): name: str = "coref", *, span_mentions: str = "coref_mentions", - span_cluster_prefix: str, + span_cluster_prefix: str = DEFAULT_CLUSTER_PREFIX, + scorer: Optional[Callable] = coref_scorer, ) -> None: """Initialize a coreference resolution component. @@ -118,7 +128,8 @@ class CoreferenceResolver(TrainablePipe): self.span_cluster_prefix = span_cluster_prefix self._rehearsal_model = None - self.cfg: Dict[str, Any] = {} + self.cfg: Dict[str, Any] = {"span_cluster_prefix": span_cluster_prefix} + self.scorer = scorer def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: """Apply the pipeline's model to a batch of docs, without modifying them. @@ -276,7 +287,6 @@ class CoreferenceResolver(TrainablePipe): log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1) log_norm = ops.softmax(score_matrix, axis=1) grad = log_norm - log_marg - # gradients.append((grad, cidx)) loss = float((grad**2).sum()) return loss, grad @@ -306,26 +316,3 @@ class CoreferenceResolver(TrainablePipe): assert len(X) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=X, Y=Y) - - def score(self, examples, **kwargs): - """Score a batch of examples using LEA. - For details on how LEA works and why to use it see the paper: - Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric - Moosavi and Strube, 2016 - https://api.semanticscholar.org/CorpusID:17606580 - """ - - evaluator = Evaluator(lea) - - for ex in examples: - p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix) - g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix) - cluster_info = get_cluster_info(p_clusters, g_clusters) - evaluator.update(cluster_info) - - score = { - "coref_f": evaluator.get_f1(), - "coref_p": evaluator.get_precision(), - "coref_r": evaluator.get_recall(), - } - return score diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index 23539dce9..d7e96a4b2 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -5,20 +5,19 @@ from thinc.types import Floats2d, Floats3d, Ints2d from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy from thinc.api import set_dropout_rate, to_categorical from itertools import islice -from statistics import mean from .trainable_pipe import TrainablePipe from ..language import Language from ..training import Example, validate_examples, validate_get_examples from ..errors import Errors -from ..scorer import Scorer +from ..scorer import Scorer, doc2clusters from ..tokens import Doc from ..vocab import Vocab +from ..util import registry from ..ml.models.coref_util import ( MentionClusters, DEFAULT_CLUSTER_PREFIX, - doc2clusters, ) default_span_predictor_config = """ @@ -52,6 +51,15 @@ depth = 2 DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"] +def span_predictor_scorer(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: + return Scorer.score_span_predictions(examples, **kwargs) + + +@registry.scorers("spacy.span_predictor_scorer.v1") +def make_span_predictor_scorer(): + return span_predictor_scorer + + @Language.factory( "span_predictor", assigns=["doc.spans"], @@ -60,6 +68,7 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)[ "model": DEFAULT_SPAN_PREDICTOR_MODEL, "input_prefix": "coref_head_clusters", "output_prefix": "coref_clusters", + "scorer": {"@scorers": "spacy.span_predictor_scorer.v1"}, }, default_score_weights={"span_accuracy": 1.0}, ) @@ -69,10 +78,16 @@ def make_span_predictor( model, input_prefix: str = "coref_head_clusters", output_prefix: str = "coref_clusters", + scorer: Optional[Callable] = span_predictor_scorer, ) -> "SpanPredictor": """Create a SpanPredictor component.""" return SpanPredictor( - nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix + nlp.vocab, + model, + name, + input_prefix=input_prefix, + output_prefix=output_prefix, + scorer=scorer, ) @@ -90,6 +105,7 @@ class SpanPredictor(TrainablePipe): *, input_prefix: str = "coref_head_clusters", output_prefix: str = "coref_clusters", + scorer: Optional[Callable] = span_predictor_scorer, ) -> None: self.vocab = vocab self.model = model @@ -97,7 +113,10 @@ class SpanPredictor(TrainablePipe): self.input_prefix = input_prefix self.output_prefix = output_prefix - self.cfg: Dict[str, Any] = {} + self.scorer = scorer + self.cfg: Dict[str, Any] = { + "output_prefix": output_prefix, + } def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: # for now pretend there's just one doc @@ -255,35 +274,3 @@ class SpanPredictor(TrainablePipe): assert len(X) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=X, Y=Y) - - def score(self, examples, **kwargs): - """ - Evaluate on reconstructing the correct spans around - gold heads. - """ - scores = [] - xp = self.model.ops.xp - for eg in examples: - starts = [] - ends = [] - pred_starts = [] - pred_ends = [] - ref = eg.reference - pred = eg.predicted - for key, gold_sg in ref.spans.items(): - if key.startswith(self.output_prefix): - pred_sg = pred.spans[key] - for gold_mention, pred_mention in zip(gold_sg, pred_sg): - starts.append(gold_mention.start) - ends.append(gold_mention.end) - pred_starts.append(pred_mention.start) - pred_ends.append(pred_mention.end) - - starts = xp.asarray(starts) - ends = xp.asarray(ends) - pred_starts = xp.asarray(pred_starts) - pred_ends = xp.asarray(pred_ends) - correct = (starts == pred_starts) * (ends == pred_ends) - accuracy = correct.mean() - scores.append(float(accuracy)) - return {"span_accuracy": mean(scores)} diff --git a/spacy/scorer.py b/spacy/scorer.py index 8ee6294ad..14b4b2a79 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -2,6 +2,7 @@ from typing import Optional, Iterable, Dict, Set, List, Any, Callable, Tuple from typing import TYPE_CHECKING import numpy as np from collections import defaultdict +from statistics import mean from .training import Example from .tokens import Token, Doc, Span @@ -9,6 +10,7 @@ from .errors import Errors from .util import get_lang_class, SimpleFrozenList from .morphology import Morphology + if TYPE_CHECKING: # This lets us add type hints for mypy etc. without causing circular imports from .language import Language # noqa: F401 @@ -873,6 +875,66 @@ class Scorer: f"{attr}_las_per_type": None, } + @staticmethod + def score_coref_clusters(examples: Iterable[Example], **cfg): + """Score a batch of examples using LEA. + + For details on how LEA works and why to use it see the paper: + Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric + Moosavi and Strube, 2016 + https://api.semanticscholar.org/CorpusID:17606580 + """ + + span_cluster_prefix = cfg["span_cluster_prefix"] + + evaluator = ClusterEvaluator(lea) + + for ex in examples: + p_clusters = doc2clusters(ex.predicted, span_cluster_prefix) + g_clusters = doc2clusters(ex.reference, span_cluster_prefix) + cluster_info = get_cluster_info(p_clusters, g_clusters) + evaluator.update(cluster_info) + + score = { + "coref_f": evaluator.get_f1(), + "coref_p": evaluator.get_precision(), + "coref_r": evaluator.get_recall(), + } + return score + + @staticmethod + def score_span_predictions(examples: Iterable[Example], **cfg): + """Evaluate reconstruction of the correct spans from gold heads. + """ + scores = [] + output_prefix = cfg["output_prefix"] + for eg in examples: + starts = [] + ends = [] + pred_starts = [] + pred_ends = [] + ref = eg.reference + pred = eg.predicted + for key, gold_sg in ref.spans.items(): + if key.startswith(output_prefix): + pred_sg = pred.spans[key] + for gold_mention, pred_mention in zip(gold_sg, pred_sg): + starts.append(gold_mention.start) + ends.append(gold_mention.end) + pred_starts.append(pred_mention.start) + pred_ends.append(pred_mention.end) + + + # see how many are perfect + cs = [a == b for a, b in zip(starts, pred_starts)] + ce = [a == b for a, b in zip(ends, pred_ends)] + correct = [int(a and b) for a, b in zip(cs, ce)] + accuracy = sum(correct) / len(correct) + + scores.append(float(accuracy)) + out_key = f"span_{output_prefix}_accuracy" + return {out_key: mean(scores)} + def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: """Compute micro-PRF and per-entity PRF scores for a sequence of examples.""" @@ -1143,3 +1205,161 @@ def _auc(x, y): # regular numpy.ndarray instances. area = area.dtype.type(area) return area + + +# The following implementations of get_cluster_info(), get_markable_assignments, +# and ClusterEvaluator are adapted from coval, which is distributed under the +# MIT License. +# Copyright 2018 Nafise Sadat Moosavi +# See licenses/3rd_party_licenses.txt +def get_cluster_info(predicted_clusters, gold_clusters): + p2g = get_markable_assignments(predicted_clusters, gold_clusters) + g2p = get_markable_assignments(gold_clusters, predicted_clusters) + # this is the data format used as input by the evaluator + return (gold_clusters, predicted_clusters, g2p, p2g) + + +def get_markable_assignments(in_clusters, out_clusters): + markable_cluster_ids = {} + out_dic = {} + for cluster_id, cluster in enumerate(out_clusters): + for m in cluster: + out_dic[m] = cluster_id + + for cluster in in_clusters: + for im in cluster: + for om in out_dic: + if im == om: + markable_cluster_ids[im] = out_dic[om] + break + + return markable_cluster_ids + + +class ClusterEvaluator: + def __init__(self, metric, beta=1, keep_aggregated_values=False): + self.p_num = 0 + self.p_den = 0 + self.r_num = 0 + self.r_den = 0 + self.metric = metric + self.beta = beta + self.keep_aggregated_values = keep_aggregated_values + + if keep_aggregated_values: + self.aggregated_p_num = [] + self.aggregated_p_den = [] + self.aggregated_r_num = [] + self.aggregated_r_den = [] + + def update(self, coref_info): + ( + key_clusters, + sys_clusters, + key_mention_sys_cluster, + sys_mention_key_cluster, + ) = coref_info + + pn, pd = self.metric(sys_clusters, key_clusters, sys_mention_key_cluster) + rn, rd = self.metric(key_clusters, sys_clusters, key_mention_sys_cluster) + self.p_num += pn + self.p_den += pd + self.r_num += rn + self.r_den += rd + + if self.keep_aggregated_values: + self.aggregated_p_num.append(pn) + self.aggregated_p_den.append(pd) + self.aggregated_r_num.append(rn) + self.aggregated_r_den.append(rd) + + def f1(self, p_num, p_den, r_num, r_den, beta=1): + p = 0 + if p_den != 0: + p = p_num / float(p_den) + r = 0 + if r_den != 0: + r = r_num / float(r_den) + + if p + r == 0: + return 0 + + return (1 + beta * beta) * p * r / (beta * beta * p + r) + + def get_f1(self): + return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) + + def get_recall(self): + if self.r_num == 0: + return 0 + + return self.r_num / float(self.r_den) + + def get_precision(self): + if self.p_num == 0: + return 0 + + return self.p_num / float(self.p_den) + + def get_prf(self): + return self.get_precision(), self.get_recall(), self.get_f1() + + def get_counts(self): + return self.p_num, self.p_den, self.r_num, self.r_den + + def get_aggregated_values(self): + return ( + self.aggregated_p_num, + self.aggregated_p_den, + self.aggregated_r_num, + self.aggregated_r_den, + ) + + +def lea(input_clusters, output_clusters, mention_to_gold): + num, den = 0, 0 + + for c in input_clusters: + if len(c) == 1: + all_links = 1 + if ( + c[0] in mention_to_gold + and len(output_clusters[mention_to_gold[c[0]]]) == 1 + ): + common_links = 1 + else: + common_links = 0 + else: + common_links = 0 + all_links = len(c) * (len(c) - 1) / 2.0 + for i, m in enumerate(c): + if m in mention_to_gold: + for m2 in c[i + 1 :]: + if ( + m2 in mention_to_gold + and mention_to_gold[m] == mention_to_gold[m2] + ): + common_links += 1 + + num += len(c) * common_links / float(all_links) + den += len(c) + + return num, den + + +# This is coref related, but not from coval. +def doc2clusters(doc: Doc, prefix: str) -> List[List[Tuple[int, int]]]: + """Given a doc, give the mention clusters. + + This is used for scoring. + """ + out = [] + for name, val in doc.spans.items(): + if not name.startswith(prefix): + continue + + cluster = [] + for mention in val: + cluster.append((mention.start, mention.end)) + out.append(cluster) + return out diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 25de6e356..53f0b2011 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -5,24 +5,26 @@ from spacy import util from spacy.training import Example from spacy.lang.en import English from spacy.tests.util import make_tempdir -from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX from spacy.ml.models.coref_util import ( + DEFAULT_CLUSTER_PREFIX, select_non_crossing_spans, get_sentence_ids, ) +from thinc.util import has_torch + # fmt: off TRAIN_DATA = [ ( "Yes, I noticed that many friends around me received it. It seems that almost everyone received this SMS.", { "spans": { - f"{DEFAULT_CLUSTERS_PREFIX}_1": [ + f"{DEFAULT_CLUSTER_PREFIX}_1": [ (5, 6, "MENTION"), # I (40, 42, "MENTION"), # me ], - f"{DEFAULT_CLUSTERS_PREFIX}_2": [ + f"{DEFAULT_CLUSTER_PREFIX}_2": [ (52, 54, "MENTION"), # it (95, 103, "MENTION"), # this SMS ] @@ -45,18 +47,20 @@ def snlp(): return en +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_add_pipe(nlp): nlp.add_pipe("coref") assert nlp.pipe_names == ["coref"] +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_not_initialized(nlp): nlp.add_pipe("coref") text = "She gave me her pen." - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="E109"): nlp(text) - +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized(nlp): nlp.add_pipe("coref") nlp.initialize() @@ -68,15 +72,16 @@ def test_initialized(nlp): assert len(v) <= 15 +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized_short(nlp): nlp.add_pipe("coref") nlp.initialize() assert nlp.pipe_names == ["coref"] text = "Hi there" doc = nlp(text) - print(doc.spans) +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_coref_serialization(nlp): # Test that the coref component can be serialized nlp.add_pipe("coref", last=True) @@ -101,6 +106,7 @@ def test_coref_serialization(nlp): # assert spans_result == spans_result2 +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_overfitting_IO(nlp): # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly train_examples = [] @@ -147,6 +153,7 @@ def test_overfitting_IO(nlp): # assert_equal(batch_deps_1, no_batch_deps) +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_crossing_spans(): starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2] ends = [12, 12, 2, 3, 3, 4, 4, 4, 3, 4, 5] @@ -158,6 +165,7 @@ def test_crossing_spans(): guess = sorted(guess) assert gold == guess +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_sentence_map(snlp): doc = snlp("I like text. This is text.") sm = get_sentence_ids(doc) diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index 794f9ca87..b3ce46e34 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -7,8 +7,9 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal import numpy from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier +from spacy.ml.models import build_spancat_model if has_torch: - from spacy.ml.models import build_spancat_model, build_wl_coref_model + from spacy.ml.models import build_wl_coref_model, build_span_predictor from spacy.ml.staticvectors import StaticVectors from spacy.ml.extract_spans import extract_spans, _get_span_indices from spacy.lang.en import English