Fix set_annotations during parser update

This commit is contained in:
Matthew Honnibal 2021-01-25 11:56:36 +11:00
parent c631c355d1
commit be155ead9b
2 changed files with 31 additions and 24 deletions

View File

@ -61,6 +61,14 @@ cdef class TransitionSystem:
offset += len(doc) offset += len(doc)
return states return states
def follow_history(self, doc, history):
    """Replay a sequence of transitions on a fresh state for `doc`.

    doc: The document used to construct the new StateClass.
    history: An iterable of integer class IDs, each used as an index
        into this transition system's action table (`self.c`).
    RETURNS (StateClass): The state after every action in `history` has
        been applied, in order.
    """
    cdef int move_id
    cdef StateClass replayed = StateClass(doc)
    for move_id in history:
        # Look up the transition for this class ID and apply it in place
        # to the underlying C state.
        # NOTE(review): no bounds or validity check is done here; this
        # presumably relies on `history` coming from the same transition
        # system (e.g. an oracle sequence) -- confirm against callers.
        transition = self.c[move_id]
        transition.do(replayed.c, transition.label)
    return replayed
def get_oracle_sequence(self, Example example, _debug=False): def get_oracle_sequence(self, Example example, _debug=False):
states, golds, _ = self.init_gold_batch([example]) states, golds, _ = self.init_gold_batch([example])
if not states: if not states:

View File

@ -317,8 +317,8 @@ cdef class Parser(TrainablePipe):
for multitask in self._multitasks: for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd) multitask.update(examples, drop=drop, sgd=sgd)
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) examples = [eg for eg in examples if self.moves.has_gold(eg)]
if n_examples == 0: if len(examples) == 0:
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
# The probability we use beam update, instead of falling back to # The probability we use beam update, instead of falling back to
@ -332,6 +332,7 @@ cdef class Parser(TrainablePipe):
losses=losses, losses=losses,
beam_density=self.cfg["beam_density"] beam_density=self.cfg["beam_density"]
) )
oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
max_moves = self.cfg["update_with_oracle_cut_size"] max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1: if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the # Chop sequences into lengths of this many words, to make the
@ -339,6 +340,7 @@ cdef class Parser(TrainablePipe):
max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch( states, golds, _ = self._init_gold_batch(
examples, examples,
oracle_histories,
max_length=max_moves max_length=max_moves
) )
else: else:
@ -370,11 +372,15 @@ cdef class Parser(TrainablePipe):
if sgd not in (None, False): if sgd not in (None, False):
self.finish_update(sgd) self.finish_update(sgd)
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
# TODO: Refactor so we don't have to parse twice like this (ugh) # If we want to set the annotations based on predictions, it's really
# hard to avoid parsing the data twice :(.
# The issue is that we cut up the gold batch into sub-states, and that # The issue is that we cut up the gold batch into sub-states, and that
# makes it hard to get the actual predicted transition sequence. # means there's no one predicted sequence during the update.
predicted_states = self.predict(docs) gold_states = [
self.set_annotations(docs, predicted_states) self.moves.follow_history(doc, history)
for doc, history in zip(docs, oracle_histories)
]
self.set_annotations(docs, gold_states)
# Ugh, this is annoying. If we're working on GPU, we want to free the # Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to # memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing. # removing these in time if we don't explicitly delete? It's confusing.
@ -581,7 +587,7 @@ cdef class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None raise ValueError(Errors.E149) from None
return self return self
def _init_gold_batch(self, examples, max_length): def _init_gold_batch(self, examples, oracle_histories, max_length):
"""Make a square batch, of length equal to the shortest transition """Make a square batch, of length equal to the shortest transition
sequence or a cap. A long sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N, doc will get multiple states. Let's say we have a doc of length 2*N,
@ -594,24 +600,17 @@ cdef class Parser(TrainablePipe):
all_states = self.moves.init_batch([eg.predicted for eg in examples]) all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = [] states = []
golds = [] golds = []
to_cut = [] for state, eg, history in zip(all_states, examples, oracle_histories):
for state, eg in zip(all_states, examples): if state.is_final():
if self.moves.has_gold(eg) and not state.is_final(): continue
gold = self.moves.init_gold(state, eg) gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length: if len(history) < max_length:
states.append(state) states.append(state)
golds.append(gold) golds.append(gold)
else: continue
oracle_actions = self.moves.get_oracle_sequence_from_state( for i in range(0, len(history), max_length):
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut:
return states, golds, 0
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
start_state = state.copy() start_state = state.copy()
for clas in oracle_actions[i:i+max_length]: for clas in history[i:i+max_length]:
action = self.moves.c[clas] action = self.moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
if state.is_final(): if state.is_final():