Update overfitting test
parent a46bc03abb
commit fd574a89c4
@@ -98,19 +98,29 @@ def test_overfitting_IO(nlp):
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
 
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        ref = eg.reference
+        # Finally, copy over the head spans to the pred
+        pred = eg.predicted
+        for key, spans in ref.spans.items():
+            if key.startswith("coref_head_clusters"):
+                pred.spans[key] = [pred[span.start:span.end] for span in spans]
+
+        train_examples.append(eg)
     nlp.add_pipe("span_predictor", config=CONFIG)
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
 
-    for i in range(1500):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
 
-    # test the trained model
-    doc = nlp(test_text)
+    # test the trained model, using the pred since it has heads
+    doc = nlp(train_examples[0].predicted)
 
     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
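For readers skimming the diff, here is the example-construction step added above, pulled out as a standalone sketch. The helper name make_train_examples is hypothetical; train_data is assumed to be the test module's TRAIN_DATA, i.e. (text, annotation) pairs whose annotations carry "coref_head_clusters*" span groups, and nlp is the pipeline from the test fixture.

# Sketch only: mirrors the loop added in this commit.
from spacy.training import Example


def make_train_examples(nlp, train_data):
    # train_data: (text, annotation-dict) pairs, e.g. the test's TRAIN_DATA
    examples = []
    for text, annot in train_data:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        pred = eg.predicted
        # Copy the gold head spans onto the predicted doc so the
        # span_predictor sees head clusters on its input during training.
        for key, spans in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                pred.spans[key] = [pred[span.start:span.end] for span in spans]
        examples.append(eg)
    return examples

The same reasoning explains the other change in this hunk: after training, the test runs nlp over train_examples[0].predicted rather than the raw text, since only the predicted doc carries the copied head spans ("using the pred since it has heads").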