only test suggester and test result exhaustively

This commit is contained in:
kadarakos 2023-06-02 08:54:45 +00:00
parent 658c4aee35
commit 37c4ad5007

View File

@ -198,29 +198,17 @@ def test_span_finder_suggester():
nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize()
# Setting bias terms to 1000 to always make it predict something.
scorer_model = span_finder.model.get_ref("scorer")
bias = scorer_model.layers[0].get_param("b")
bias += 1000
scorer_model.layers[0].set_param("b", bias)
span_finder.set_annotations(docs, span_finder.predict(docs))
docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]]
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
spans_key=SPANS_KEY
)
candidates = suggester(docs)
span_length = 0
for doc in docs:
span_length += len(doc.spans[SPANS_KEY])
assert span_length == len(candidates.dataXd)
assert type(candidates) == Ragged
assert len(candidates.dataXd[0]) == 2
assert len(candidates) == 2
assert list(candidates.dataXd[0]) == [3, 4]
assert list(candidates.dataXd[1]) == [0, 4]
assert list(candidates.dataXd[2]) == [3, 5]
assert list(candidates.lengths) == [1, 2]
def test_overfitting_IO():