Revert "Fix set_annotations in parser.update"

This reverts commit c6df0eafd0.
This commit is contained in:
Matthew Honnibal 2021-01-25 11:22:57 +11:00
parent 65f2270d59
commit c631c355d1
4 changed files with 18 additions and 55 deletions

View File

@ -193,11 +193,7 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
loss += (d_scores**2).mean()
bp_scores(d_scores)
# Return the predicted sequence for each doc.
predicted_histories = []
for i in range(len(pbeam)):
predicted_histories.append(pbeam[i].histories[0])
return predicted_histories, loss
return loss
def collect_states(beams, docs):

View File

@ -638,17 +638,16 @@ cdef class ArcEager(TransitionSystem):
return gold
def init_gold_batch(self, examples):
# TODO: Projectivity?
all_states = self.init_batch([eg.predicted for eg in examples])
golds = []
states = []
docs = []
for state, eg in zip(all_states, examples):
if self.has_gold(eg) and not state.is_final():
golds.append(self.init_gold(state, eg))
states.append(state)
docs.append(eg.x)
n_steps = sum([len(s.queue) for s in states])
return states, golds, docs
return states, golds, n_steps
def _replace_unseen_labels(self, ArcEagerGold gold):
backoff_label = self.strings["dep"]

View File

@ -120,16 +120,6 @@ cdef class TransitionSystem:
raise ValueError(Errors.E024)
return history
def follow_history(self, doc, history):
"""Get the state that results from following a sequence of actions."""
cdef int clas
cdef StateClass state
state = self.init_batch([doc])[0]
for clas in history:
action = self.c[clas]
action.do(state.c, action.label)
return state
def apply_transition(self, StateClass state, name):
if not self.is_valid(state, name):
raise ValueError(Errors.E170.format(name=name))

View File

@ -337,22 +337,21 @@ cdef class Parser(TrainablePipe):
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, max_moves, state2doc = self._init_gold_batch(
states, golds, _ = self._init_gold_batch(
examples,
max_length=max_moves
)
else:
states, golds, state2doc = self.moves.init_gold_batch(examples)
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
histories = [[] for example in examples]
all_states = list(states)
states_golds = list(zip(states, golds, state2doc))
states_golds = list(zip(states, golds))
n_moves = 0
while states_golds:
states, golds, state2doc = zip(*states_golds)
states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses)
# Note that the gradient isn't normalized by the batch size
@ -361,13 +360,8 @@ cdef class Parser(TrainablePipe):
# be getting smaller gradients for states in long sequences.
backprop(d_scores)
# Follow the predicted action
actions = self.transition_states(states, scores)
for i, action in enumerate(actions):
histories[i].append(action)
states_golds = [
s for s in zip(states, golds, state2doc)
if not s[0].is_final()
]
self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
@ -376,11 +370,11 @@ cdef class Parser(TrainablePipe):
if sgd not in (None, False):
self.finish_update(sgd)
docs = [eg.predicted for eg in examples]
states = [
self.moves.follow_history(doc, history)
for doc, history in zip(docs, histories)
]
self.set_annotations(docs, self._get_states(docs, states))
# TODO: Refactor so we don't have to parse twice like this (ugh)
# The issue is that we cut up the gold batch into sub-states, and that
# makes it hard to get the actual predicted transition sequence.
predicted_states = self.predict(docs)
self.set_annotations(docs, predicted_states)
# Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing.
@ -441,16 +435,13 @@ cdef class Parser(TrainablePipe):
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0):
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
states, golds, docs = self.moves.init_gold_batch(examples)
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
predicted_histories, loss = _beam_utils.update_beam(
loss = _beam_utils.update_beam(
self.moves,
states,
golds,
@ -462,12 +453,6 @@ cdef class Parser(TrainablePipe):
backprop_tok2vec(golds)
if sgd is not None:
self.finish_update(sgd)
states = [
self.moves.follow_history(doc, history)
for doc, history in zip(docs, predicted_histories)
]
self.set_annotations(docs, states)
return losses
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
cdef StateClass state
@ -610,24 +595,18 @@ cdef class Parser(TrainablePipe):
states = []
golds = []
to_cut = []
# Return a list indicating the position in the batch that each state
# refers to. This lets us put together the full list of predicted
# histories.
state2doc = []
doc2i = {eg.x: i for i, eg in enumerate(examples)}
for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
state2doc.append(doc2i[eg.x])
else:
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut:
return states, golds, 0, state2doc
return states, golds, 0
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
@ -640,7 +619,6 @@ cdef class Parser(TrainablePipe):
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
state2doc.append(doc2i[eg.x])
if state.is_final():
break
return states, golds, max_length, state2doc
return states, golds, max_length