mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 16:24:16 +03:00
Fix parser gold cutting and gradient normalization
This commit is contained in:
parent
8c5a88e777
commit
a1b6add4c8
|
@ -265,11 +265,15 @@ cdef class Parser:
|
|||
free(is_valid)
|
||||
|
||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
cdef StateClass state
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.)
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(examples, drop=drop, sgd=sgd)
|
||||
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||
if n_examples == 0:
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
model, backprop_tok2vec = self.model.begin_update(
|
||||
|
@ -280,10 +284,13 @@ cdef class Parser:
|
|||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
||||
states, golds, max_steps = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=numpy.random.choice(range(20, cut_size))
|
||||
max_length=numpy.random.choice(range(5, cut_size))
|
||||
)
|
||||
else:
|
||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
max_steps = max([len(eg.x) for eg in examples])
|
||||
if not states:
|
||||
return losses
|
||||
all_states = list(states)
|
||||
states_golds = zip(states, golds)
|
||||
for _ in range(max_steps):
|
||||
|
@ -292,6 +299,17 @@ cdef class Parser:
|
|||
states, golds = zip(*states_golds)
|
||||
scores, backprop = model.begin_update(states)
|
||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
||||
if self.cfg["normalize_gradients_with_batch_size"]:
|
||||
# We have to be very careful how we do this, because of the way we
|
||||
# cut up the batch. We subdivide long sequences. If we normalize
|
||||
# naively, we end up normalizing by sequence length, which
|
||||
# is bad: that would mean that states in long sequences
|
||||
# consistently get smaller gradients. Imagine if we have two
|
||||
# sequences, one length 1000, one length 20. If we cut up
|
||||
# the 1k sequence so that we have a "batch" of 50 subsequences,
|
||||
# we don't want the gradients to get 50 times smaller!
|
||||
d_scores /= n_examples
|
||||
|
||||
backprop(d_scores)
|
||||
# Follow the predicted action
|
||||
self.transition_states(states, scores)
|
||||
|
@ -389,8 +407,6 @@ cdef class Parser:
|
|||
cpu_log_loss(c_d_scores,
|
||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
c_d_scores += d_scores.shape[1]
|
||||
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
|
||||
d_scores /= len(states)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.)
|
||||
losses[self.name] += (d_scores**2).sum()
|
||||
|
@ -503,41 +519,49 @@ cdef class Parser:
|
|||
return self
|
||||
|
||||
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
"""Make a square batch, of length equal to the shortest transition
|
||||
sequence or a cap. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
where N is the shortest doc. We'll make two states, one representing
|
||||
long_doc[:N], and another representing long_doc[N:]."""
|
||||
cdef:
|
||||
StateClass start_state
|
||||
StateClass state
|
||||
Transition action
|
||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||
kept = []
|
||||
max_length_seen = 0
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.moves.has_gold(eg) and not state.is_final():
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
kept.append((eg, state, gold))
|
||||
max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples])))
|
||||
max_moves = 0
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
kept.append((eg, state, gold, oracle_actions))
|
||||
min_length = min(min_length, len(oracle_actions))
|
||||
max_length_seen = max(max_length, len(oracle_actions))
|
||||
if not kept:
|
||||
return [], [], 0
|
||||
max_length = max(min_length, min(max_length, max_length_seen))
|
||||
states = []
|
||||
golds = []
|
||||
for eg, state, gold in kept:
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state, gold)
|
||||
start = 0
|
||||
while start < len(eg.predicted):
|
||||
state = state.copy()
|
||||
cdef int clas
|
||||
max_moves = 0
|
||||
for eg, state, gold, oracle_actions in kept:
|
||||
for i in range(0, len(oracle_actions), max_length):
|
||||
start_state = state.copy()
|
||||
n_moves = 0
|
||||
while state.B(0) < start and not state.is_final():
|
||||
action = self.moves.c[oracle_actions.pop(0)]
|
||||
for clas in oracle_actions[i:i+max_length]:
|
||||
action = self.moves.c[clas]
|
||||
action.do(state.c, action.label)
|
||||
state.c.push_hist(action.clas)
|
||||
n_moves += 1
|
||||
has_gold = self.moves.has_gold(eg, start=start,
|
||||
end=start+max_length)
|
||||
if not state.is_final() and has_gold:
|
||||
states.append(state)
|
||||
if state.is_final():
|
||||
break
|
||||
max_moves = max(max_moves, n_moves)
|
||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||
states.append(start_state)
|
||||
golds.append(gold)
|
||||
max_moves = max(max_moves, n_moves)
|
||||
start += min(max_length, len(eg.x)-start)
|
||||
max_moves = max(max_moves, len(oracle_actions))
|
||||
if state.is_final():
|
||||
break
|
||||
return states, golds, max_moves
|
||||
|
|
Loading…
Reference in New Issue
Block a user