Improve cutting logic in parser

This commit is contained in:
Matthw Honnibal 2020-07-08 11:27:54 +02:00
parent 42e1109def
commit ca989f4cc4

View File

@ -292,10 +292,8 @@ cdef class Parser:
if not states: if not states:
return losses return losses
all_states = list(states) all_states = list(states)
states_golds = zip(states, golds) states_golds = list(zip(states, golds))
for _ in range(max_steps): while states_golds:
if not states_golds:
break
states, golds = zip(*states_golds) states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states) scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses) d_scores = self.get_batch_loss(states, golds, scores, losses)
@ -519,21 +517,25 @@ cdef class Parser:
StateClass state StateClass state
Transition action Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples]) all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
kept = [] kept = []
max_length_seen = 0 max_length_seen = 0
for state, eg in zip(all_states, examples): for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final(): if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg) gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
else:
oracle_actions = self.moves.get_oracle_sequence_from_state( oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold) state.copy(), gold)
kept.append((eg, state, gold, oracle_actions)) kept.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions)) min_length = min(min_length, len(oracle_actions))
max_length_seen = max(max_length, len(oracle_actions)) max_length_seen = max(max_length, len(oracle_actions))
if not kept: if not kept:
return [], [], 0 return states, golds, 0
max_length = max(min_length, min(max_length, max_length_seen)) max_length = max(min_length, min(max_length, max_length_seen))
states = []
golds = []
cdef int clas cdef int clas
max_moves = 0 max_moves = 0
for eg, state, gold, oracle_actions in kept: for eg, state, gold, oracle_actions in kept: