Small update_with_oracle_cut_size fixes

Fix an off-by-one in `TransitionModel.forward`, where we always did one
move more than the maximum number of moves.

This exposed another issue: when creating cut states, we skipped states
where the (maximum number of) moves from that state only applied
transitions that did not modify the buffer.

Replace uses of `random.uniform` with `random.randrange`.
This commit is contained in:
Daniël de Kok 2023-02-21 15:51:29 +01:00
parent e27c60a702
commit 10f5e9413d
2 changed files with 5 additions and 6 deletions

View File

@@ -338,9 +338,9 @@ def _forward_fallback(
all_ids.append(ids)
all_statevecs.append(statevecs)
all_which.append(which)
n_moves += 1
if n_moves >= max_moves >= 1:
break
n_moves += 1
def backprop_parser(d_states_d_scores):
ids = ops.xp.vstack(all_ids)

View File

@@ -258,7 +258,7 @@ class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
@@ -425,7 +425,7 @@ class Parser(TrainablePipe):
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
init_states, gold_states, _ = self._init_gold_batch(
examples,
max_length=max_moves
@@ -729,7 +729,6 @@ class Parser(TrainablePipe):
action.do(state.c, action.label)
if state.is_final():
break
if moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
if state.is_final():