diff --git a/spacy/tests/pipeline/test_span_predictor.py b/spacy/tests/pipeline/test_span_predictor.py index 9281df354..4434b6651 100644 --- a/spacy/tests/pipeline/test_span_predictor.py +++ b/spacy/tests/pipeline/test_span_predictor.py @@ -98,19 +98,29 @@ def test_overfitting_IO(nlp): for text, annot in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + ref = eg.reference + # Finally, copy over the head spans to the pred + pred = eg.predicted + for key, spans in ref.spans.items(): + if key.startswith("coref_head_clusters"): + pred.spans[key] = [pred[span.start:span.end] for span in spans] + + train_examples.append(eg) nlp.add_pipe("span_predictor", config=CONFIG) optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text) - # Needs ~12 epochs to converge - for i in range(15): + for i in range(1500): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) doc = nlp(test_text) - # test the trained model - doc = nlp(test_text) + # test the trained model, using the pred since it has heads + doc = nlp(train_examples[0].predicted) # Also test the results are still the same after IO with make_tempdir() as tmp_dir: