diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py index 9777dc9f6..ebe1879d4 100644 --- a/spacy/tests/pipeline/test_span_finder.py +++ b/spacy/tests/pipeline/test_span_finder.py @@ -198,29 +198,17 @@ def test_span_finder_suggester(): nlp = Language() docs = [nlp("This is an example."), nlp("This is the second example.")] docs[0].spans[SPANS_KEY] = [docs[0][3:4]] - docs[1].spans[SPANS_KEY] = [docs[1][3:5]] - span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) - nlp.initialize() - # Setting bias terms to 1000 to always make it predict something. - scorer_model = span_finder.model.get_ref("scorer") - bias = scorer_model.layers[0].get_param("b") - bias += 1000 - scorer_model.layers[0].set_param("b", bias) - span_finder.set_annotations(docs, span_finder.predict(docs)) - + docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]] suggester = registry.misc.get("spacy.span_finder_suggester.v1")( spans_key=SPANS_KEY ) - candidates = suggester(docs) - - span_length = 0 - for doc in docs: - span_length += len(doc.spans[SPANS_KEY]) - - assert span_length == len(candidates.dataXd) assert type(candidates) == Ragged - assert len(candidates.dataXd[0]) == 2 + assert len(candidates) == 2 + assert list(candidates.dataXd[0]) == [3, 4] + assert list(candidates.dataXd[1]) == [0, 4] + assert list(candidates.dataXd[2]) == [3, 5] + assert list(candidates.lengths) == [1, 2] def test_overfitting_IO():