Fix parser gold cutting and gradient normalization

This commit is contained in:
Matthw Honnibal 2020-07-01 01:02:58 +02:00
parent 8c5a88e777
commit a1b6add4c8

View File

@ -265,11 +265,15 @@ cdef class Parser:
free(is_valid) free(is_valid)
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None): def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
cdef StateClass state
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
for multitask in self._multitasks: for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd) multitask.update(examples, drop=drop, sgd=sgd)
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
if n_examples == 0:
return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
# Prepare the stepwise model, and get the callback for finishing the batch # Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update( model, backprop_tok2vec = self.model.begin_update(
@ -280,10 +284,13 @@ cdef class Parser:
cut_size = self.cfg["update_with_oracle_cut_size"] cut_size = self.cfg["update_with_oracle_cut_size"]
states, golds, max_steps = self._init_gold_batch( states, golds, max_steps = self._init_gold_batch(
examples, examples,
max_length=numpy.random.choice(range(20, cut_size)) max_length=numpy.random.choice(range(5, cut_size))
) )
else: else:
states, golds, max_steps = self.moves.init_gold_batch(examples) states, golds, _ = self.moves.init_gold_batch(examples)
max_steps = max([len(eg.x) for eg in examples])
if not states:
return losses
all_states = list(states) all_states = list(states)
states_golds = zip(states, golds) states_golds = zip(states, golds)
for _ in range(max_steps): for _ in range(max_steps):
@ -292,6 +299,17 @@ cdef class Parser:
states, golds = zip(*states_golds) states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states) scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses) d_scores = self.get_batch_loss(states, golds, scores, losses)
if self.cfg["normalize_gradients_with_batch_size"]:
# We have to be very careful how we do this, because of the way we
# cut up the batch. We subdivide long sequences. If we normalize
# naively, we end up normalizing by sequence length, which
# is bad: that would mean that states in long sequences
# consistently get smaller gradients. Imagine if we have two
# sequences, one length 1000, one length 20. If we cut up
# the 1k sequence so that we have a "batch" of 50 subsequences,
# we don't want the gradients to get 50 times smaller!
d_scores /= n_examples
backprop(d_scores) backprop(d_scores)
# Follow the predicted action # Follow the predicted action
self.transition_states(states, scores) self.transition_states(states, scores)
@ -389,8 +407,6 @@ cdef class Parser:
cpu_log_loss(c_d_scores, cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1]) costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1] c_d_scores += d_scores.shape[1]
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
d_scores /= len(states)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
losses[self.name] += (d_scores**2).sum() losses[self.name] += (d_scores**2).sum()
@ -503,41 +519,49 @@ cdef class Parser:
return self return self
def _init_gold_batch(self, examples, min_length=5, max_length=500): def _init_gold_batch(self, examples, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long """Make a square batch, of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N, doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:].""" long_doc[:N], and another representing long_doc[N:]."""
cdef: cdef:
StateClass start_state
StateClass state StateClass state
Transition action Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples]) all_states = self.moves.init_batch([eg.predicted for eg in examples])
kept = [] kept = []
max_length_seen = 0
for state, eg in zip(all_states, examples): for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final(): if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg) gold = self.moves.init_gold(state, eg)
kept.append((eg, state, gold)) oracle_actions = self.moves.get_oracle_sequence_from_state(
max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples]))) state.copy(), gold)
max_moves = 0 kept.append((eg, state, gold, oracle_actions))
min_length = min(min_length, len(oracle_actions))
max_length_seen = max(max_length, len(oracle_actions))
if not kept:
return [], [], 0
max_length = max(min_length, min(max_length, max_length_seen))
states = [] states = []
golds = [] golds = []
for eg, state, gold in kept: cdef int clas
oracle_actions = self.moves.get_oracle_sequence_from_state( max_moves = 0
state, gold) for eg, state, gold, oracle_actions in kept:
start = 0 for i in range(0, len(oracle_actions), max_length):
while start < len(eg.predicted): start_state = state.copy()
state = state.copy()
n_moves = 0 n_moves = 0
while state.B(0) < start and not state.is_final(): for clas in oracle_actions[i:i+max_length]:
action = self.moves.c[oracle_actions.pop(0)] action = self.moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
state.c.push_hist(action.clas) state.c.push_hist(action.clas)
n_moves += 1 n_moves += 1
has_gold = self.moves.has_gold(eg, start=start, if state.is_final():
end=start+max_length) break
if not state.is_final() and has_gold: max_moves = max(max_moves, n_moves)
states.append(state) if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold) golds.append(gold)
max_moves = max(max_moves, n_moves) max_moves = max(max_moves, n_moves)
start += min(max_length, len(eg.x)-start) if state.is_final():
max_moves = max(max_moves, len(oracle_actions)) break
return states, golds, max_moves return states, golds, max_moves