Try to fix parser training

Matthew Honnibal 2021-01-25 14:40:05 +11:00
parent 3a6b93ae3a
commit 456c881ae3
2 changed files with 10 additions and 7 deletions


@@ -83,6 +83,8 @@ cdef class TransitionSystem:
    def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
        if state.is_final():
            return []
        if not self.has_gold(gold):
            return []
        cdef Pool mem = Pool()
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
        assert self.n_moves > 0
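The new guard means an example that carries no usable gold annotation now yields an empty oracle sequence instead of proceeding to cost calculation. A minimal sketch (toy Python, not spaCy's API) of the contract this gives callers, where an empty sequence simply means "nothing to supervise here":

# Toy sketch: consumers of the oracle can skip empty sequences outright.
def count_supervised_actions(oracle_sequences):
    total = 0
    for seq in oracle_sequences:
        if not seq:        # final state or no gold annotation
            continue
        total += len(seq)
    return total

assert count_supervised_actions([[], [1, 0, 2], []]) == 3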


@@ -316,8 +316,9 @@ cdef class Parser(TrainablePipe):
        validate_examples(examples, "Parser.update")
        for multitask in self._multitasks:
            multitask.update(examples, drop=drop, sgd=sgd)
        examples = [eg for eg in examples if self.moves.has_gold(eg)]
        # We need to take care to act on the whole batch, because we might be
        # getting vectors via a listener.
        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
        if len(examples) == 0:
            return losses
        set_dropout_rate(self.model, drop)
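The comment above is the core of the fix: when the parser shares its Tok2Vec through a listener, the listener hands out vectors for exactly the docs it saw on the forward pass and expects one gradient per doc on the backward pass, so filtering the batch between those two points misaligns them. A toy sketch (plain Python, not spaCy's listener implementation) of that constraint:

# Toy listener: produces one "vector" per doc and requires one gradient per doc.
class ToyListener:
    def begin_update(self, docs):
        batch = list(docs)                          # remember the exact batch
        outputs = [f"vec({doc})" for doc in batch]  # stand-in for real vectors
        def backprop(d_outputs):
            # Gradients must line up one-to-one with the docs seen above.
            assert len(d_outputs) == len(batch), "batch was filtered in between"
            return d_outputs
        return outputs, backprop

listener = ToyListener()
docs = ["doc with gold", "doc without gold"]
vectors, backprop = listener.begin_update(docs)
backprop([0.0, 0.0])    # fine: same batch size as the forward pass
# backprop([0.0])       # would fail: the batch was cut down after the forward pass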
@@ -347,7 +348,8 @@ cdef class Parser(TrainablePipe):
        states, golds, _ = self.moves.init_gold_batch(examples)
        if not states:
            return losses
        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
        docs = [eg.predicted for eg in examples]
        model, backprop_tok2vec = self.model.begin_update(docs)
        all_states = list(states)
        states_golds = list(zip(states, golds))
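The other change in this hunk builds the docs list from eg.predicted before calling begin_update. In spaCy v3 these refer to the same Doc (Example.x is an alias for Example.predicted, the document the pipeline actually runs over), so the batch fed to the model is unchanged; the point is to build the list once so it can be reused later in the method instead of being recomputed. A small illustration, assuming a spaCy v3 installation:

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("A short sentence")
eg = Example.from_dict(doc, {"words": ["A", "short", "sentence"]})
assert eg.x is eg.predicted    # same Doc object: x is just an alias
docs = [eg.predicted]          # the batch that begin_update receives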
@@ -371,7 +373,6 @@ cdef class Parser(TrainablePipe):
        backprop_tok2vec(golds)
        if sgd not in (None, False):
            self.finish_update(sgd)
        docs = [eg.predicted for eg in examples]
        # If we want to set the annotations based on predictions, it's really
        # hard to avoid parsing the data twice :(.
        # The issue is that we cut up the gold batch into sub-states, and that
@@ -601,7 +602,7 @@ cdef class Parser(TrainablePipe):
        states = []
        golds = []
        for state, eg, history in zip(all_states, examples, oracle_histories):
            if state.is_final():
            if not history:
                continue
            gold = self.moves.init_gold(state, eg)
            if len(history) < max_length:
@@ -609,6 +610,8 @@ cdef class Parser(TrainablePipe):
                golds.append(gold)
                continue
            for i in range(0, len(history), max_length):
                if state.is_final():
                    break
                start_state = state.copy()
                for clas in history[i:i+max_length]:
                    action = self.moves.c[clas]
@@ -618,6 +621,4 @@ cdef class Parser(TrainablePipe):
                if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
                    states.append(start_state)
                    golds.append(gold)
                if state.is_final():
                    break
        return states, golds, max_length
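The last three hunks adjust how the gold batch is cut into sub-states: examples with an empty oracle history are now skipped up front, and the is_final() check moves to the top of the chunk loop so no work is started on an already-finished state. Setting aside the state replay and the has_gold check, the cutting itself amounts to splitting each history into pieces of at most max_length transitions, roughly like this toy function (plain Python, not the Cython above):

def cut_history(history, max_length):
    if not history:                    # new guard: nothing to learn from
        return []
    if len(history) < max_length:      # short enough: keep the sequence whole
        return [history]
    return [history[i:i + max_length]  # otherwise one sub-sequence per chunk
            for i in range(0, len(history), max_length)]

assert cut_history([], 3) == []
assert cut_history([1, 2], 3) == [[1, 2]]
assert cut_history([1, 2, 3, 4, 5, 6, 7], 3) == [[1, 2, 3], [4, 5, 6], [7]]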