Update overfitting test

This commit is contained in:
Paul O'Leary McCann 2022-07-03 19:34:15 +09:00
parent a46bc03abb
commit fd574a89c4

View File

@@ -98,19 +98,29 @@ def test_overfitting_IO(nlp):
for text, annot in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
train_examples = []
for text, annot in TRAIN_DATA:
eg = Example.from_dict(nlp.make_doc(text), annot)
ref = eg.reference
# Finally, copy over the head spans to the pred
pred = eg.predicted
for key, spans in ref.spans.items():
if key.startswith("coref_head_clusters"):
pred.spans[key] = [pred[span.start:span.end] for span in spans]
train_examples.append(eg)
nlp.add_pipe("span_predictor", config=CONFIG)
optimizer = nlp.initialize()
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)
# Needs ~12 epochs to converge
for i in range(15):
for i in range(1500):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
doc = nlp(test_text)
# test the trained model
doc = nlp(test_text)
# test the trained model, using the pred since it has heads
doc = nlp(train_examples[0].predicted)
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir: