mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Return examples from init_gold_batch
This commit is contained in:
parent
4925c0be34
commit
67c82dbea9
|
@ -272,7 +272,7 @@ cdef class Parser:
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
[eg.predicted for eg in examples])
|
[eg.predicted for eg in examples])
|
||||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
states, golds, examples, max_steps = self.moves.init_gold_batch(examples)
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = zip(states, golds)
|
states_golds = zip(states, golds)
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
@ -285,6 +285,7 @@ cdef class Parser:
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user