Try different oracle cuts

This commit is contained in:
Matthew Honnibal 2020-09-04 03:52:29 +02:00
parent 1c07820681
commit cd9194c823

View File

@@ -281,7 +281,7 @@ cdef class Parser(Pipe):
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch(
states, golds, max_moves = self._init_gold_batch(
examples,
max_length=max_moves
)
@@ -304,10 +304,9 @@ cdef class Parser(Pipe):
# Follow the predicted action
self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
n_moves += 1
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
backprop_tok2vec(golds)
if sgd not in (None, False):
self.model.finish_update(sgd)
@@ -513,21 +512,19 @@ cdef class Parser(Pipe):
StateClass state
Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
to_cut = []
for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
else:
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
states = []
golds = []
if not to_cut:
return states, golds, 0
lengths = [len(x[-1]) for x in to_cut]
max_length = min(min(lengths), max_length)
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):