mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
tok2vec.update instead of predict (#6113)
This commit is contained in:
parent
e0e793be4d
commit
86a08f819d
|
@ -128,7 +128,7 @@ def debug_model(
|
||||||
goldY = None
|
goldY = None
|
||||||
for e in range(3):
|
for e in range(3):
|
||||||
if tok2vec:
|
if tok2vec:
|
||||||
tok2vec.predict(X)
|
tok2vec.update([Example.from_dict(x, {}) for x in X])
|
||||||
Y, get_dX = model.begin_update(X)
|
Y, get_dX = model.begin_update(X)
|
||||||
if goldY is None:
|
if goldY is None:
|
||||||
goldY = _simulate_gold(Y)
|
goldY = _simulate_gold(Y)
|
||||||
|
|
|
@ -184,7 +184,7 @@ def test_tok2vec_listener_callback():
|
||||||
label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")]
|
label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")]
|
||||||
tagger.model.initialize(X=docs, Y=label_sample)
|
tagger.model.initialize(X=docs, Y=label_sample)
|
||||||
docs = [nlp.make_doc("Another entirely random sentence")]
|
docs = [nlp.make_doc("Another entirely random sentence")]
|
||||||
tok2vec.predict(docs)
|
tok2vec.update([Example.from_dict(x, {}) for x in docs])
|
||||||
Y, get_dX = tagger.model.begin_update(docs)
|
Y, get_dX = tagger.model.begin_update(docs)
|
||||||
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
||||||
assert get_dX(Y) is not None
|
assert get_dX(Y) is not None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user