Try different oracle cuts

This commit is contained in:
Matthew Honnibal 2020-09-04 03:52:29 +02:00
parent 1c07820681
commit cd9194c823

View File

@@ -281,7 +281,7 @@ cdef class Parser(Pipe):
# Chop sequences into lengths of this many words, to make the # Chop sequences into lengths of this many words, to make the
# batch uniform length. # batch uniform length.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch( states, golds, max_moves = self._init_gold_batch(
examples, examples,
max_length=max_moves max_length=max_moves
) )
@@ -304,10 +304,9 @@ cdef class Parser(Pipe):
# Follow the predicted action # Follow the predicted action
self.transition_states(states, scores) self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
n_moves += 1
if max_moves >= 1 and n_moves >= max_moves: if max_moves >= 1 and n_moves >= max_moves:
break break
n_moves += 1
backprop_tok2vec(golds) backprop_tok2vec(golds)
if sgd not in (None, False): if sgd not in (None, False):
self.model.finish_update(sgd) self.model.finish_update(sgd)
@@ -513,21 +512,19 @@ cdef class Parser(Pipe):
StateClass state StateClass state
Transition action Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples]) all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
to_cut = [] to_cut = []
for state, eg in zip(all_states, examples): for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final(): if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg) gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length: oracle_actions = self.moves.get_oracle_sequence_from_state(
states.append(state) state.copy(), gold)
golds.append(gold) to_cut.append((eg, state, gold, oracle_actions))
else: states = []
oracle_actions = self.moves.get_oracle_sequence_from_state( golds = []
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut: if not to_cut:
return states, golds, 0 return states, golds, 0
lengths = [len(x[-1]) for x in to_cut]
max_length = min(min(lengths), max_length)
cdef int clas cdef int clas
for eg, state, gold, oracle_actions in to_cut: for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length): for i in range(0, len(oracle_actions), max_length):