From 9372b22d32308210c2684ac298453897fe83ac9c Mon Sep 17 00:00:00 2001 From: kadarakos Date: Fri, 2 Jun 2023 10:08:16 +0000 Subject: [PATCH] move preset_spans_suggester test to spancat tests --- spacy/tests/pipeline/test_span_finder.py | 18 ------------------ spacy/tests/pipeline/test_spancat.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py index 7050f4653..4caa3a33f 100644 --- a/spacy/tests/pipeline/test_span_finder.py +++ b/spacy/tests/pipeline/test_span_finder.py @@ -1,6 +1,5 @@ import pytest from thinc.api import Config -from thinc.types import Ragged from spacy.language import Language from spacy.lang.en import English @@ -193,23 +192,6 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): assert all(min_length <= len(span) <= max_length for span in doc.spans[SPANS_KEY]) -def test_span_finder_suggester(): - nlp = Language() - docs = [nlp("This is an example."), nlp("This is the second example.")] - docs[0].spans[SPANS_KEY] = [docs[0][3:4]] - docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]] - suggester = registry.misc.get("spacy.preset_spans_suggester.v1")( - spans_key=SPANS_KEY - ) - candidates = suggester(docs) - assert type(candidates) == Ragged - assert len(candidates) == 2 - assert list(candidates.dataXd[0]) == [3, 4] - assert list(candidates.dataXd[1]) == [0, 4] - assert list(candidates.dataXd[2]) == [3, 5] - assert list(candidates.lengths) == [1, 2] - - def test_overfitting_IO(): # Simple test to try and quickly overfit the span_finder component - ensuring the ML models work correctly fix_random_seed(0) diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 199ef2b2a..1d29c89ec 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -406,6 +406,23 @@ def test_ngram_sizes(en_tokenizer): assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9]) +def test_preset_spans_suggester(): + nlp = Language() + docs = [nlp("This is an example."), nlp("This is the second example.")] + docs[0].spans[SPAN_KEY] = [docs[0][3:4]] + docs[1].spans[SPAN_KEY] = [docs[1][0:4], docs[1][3:5]] + suggester = registry.misc.get("spacy.preset_spans_suggester.v1")( + spans_key=SPAN_KEY + ) + candidates = suggester(docs) + assert type(candidates) == Ragged + assert len(candidates) == 2 + assert list(candidates.dataXd[0]) == [3, 4] + assert list(candidates.dataXd[1]) == [0, 4] + assert list(candidates.dataXd[2]) == [3, 5] + assert list(candidates.lengths) == [1, 2] + + def test_overfitting_IO(): # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly fix_random_seed(0)