Update neural net tests

This commit is contained in:
Matthew Honnibal 2017-05-19 18:11:29 -05:00
parent 08766240c3
commit 836fe1d880

View File

@ -55,26 +55,17 @@ def test_build_model(parser):
def test_predict_doc(parser, tok2vec, model, doc): def test_predict_doc(parser, tok2vec, model, doc):
state = {} doc.tensor = tok2vec([doc])
state['tokvecs'] = tok2vec([doc])
parser.model = model parser.model = model
parser(doc, state=state) parser(doc)
def test_update_doc(parser, tok2vec, model, doc, gold): def test_update_doc(parser, tok2vec, model, doc, gold):
parser.model = model parser.model = model
tokvecs, bp_tokvecs = tok2vec.begin_update([doc]) tokvecs, bp_tokvecs = tok2vec.begin_update([doc])
state = {'tokvecs': tokvecs, 'bp_tokvecs': bp_tokvecs} d_tokvecs = parser.update((doc, tokvecs), gold)
state = parser.update(doc, gold, state=state) assert d_tokvecs.shape == tokvecs.shape
loss1 = state['parser_loss']
assert loss1 > 0
state = parser.update(doc, gold, state=state)
loss2 = state['parser_loss']
assert loss2 == loss1
def optimize(weights, gradient, key=None): def optimize(weights, gradient, key=None):
weights -= 0.001 * gradient weights -= 0.001 * gradient
state = parser.update(doc, gold, sgd=optimize, state=state) bp_tokvecs(d_tokvecs, sgd=optimize)
loss3 = state['parser_loss'] assert d_tokvecs.sum() == 0.
state = parser.update(doc, gold, sgd=optimize, state=state)
lossr = state['parser_loss']
assert loss3 < loss2