Improve cutting logic in parser

This commit is contained in:
Matthw Honnibal 2020-07-08 11:27:54 +02:00
parent 42e1109def
commit ca989f4cc4

View File

@ -292,10 +292,8 @@ cdef class Parser:
if not states:
return losses
all_states = list(states)
states_golds = zip(states, golds)
for _ in range(max_steps):
if not states_golds:
break
states_golds = list(zip(states, golds))
while states_golds:
states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses)
@ -519,21 +517,25 @@ cdef class Parser:
StateClass state
Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
kept = []
max_length_seen = 0
for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
kept.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions))
max_length_seen = max(max_length, len(oracle_actions))
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
else:
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
kept.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions))
max_length_seen = max(max_length, len(oracle_actions))
if not kept:
return [], [], 0
return states, golds, 0
max_length = max(min_length, min(max_length, max_length_seen))
states = []
golds = []
cdef int clas
max_moves = 0
for eg, state, gold, oracle_actions in kept: