mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 03:26:24 +03:00
Fix set_annotations in parser.update
This commit is contained in:
parent
bb15d5b22f
commit
c6df0eafd0
|
@ -193,7 +193,11 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
||||||
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
|
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
|
||||||
loss += (d_scores**2).mean()
|
loss += (d_scores**2).mean()
|
||||||
bp_scores(d_scores)
|
bp_scores(d_scores)
|
||||||
return loss
|
# Return the predicted sequence for each doc.
|
||||||
|
predicted_histories = []
|
||||||
|
for i in range(len(pbeam)):
|
||||||
|
predicted_histories.append(pbeam[i].histories[0])
|
||||||
|
return predicted_histories, loss
|
||||||
|
|
||||||
|
|
||||||
def collect_states(beams, docs):
|
def collect_states(beams, docs):
|
||||||
|
|
|
@ -638,16 +638,17 @@ cdef class ArcEager(TransitionSystem):
|
||||||
return gold
|
return gold
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
|
def init_gold_batch(self, examples):
|
||||||
# TODO: Projectivity?
|
|
||||||
all_states = self.init_batch([eg.predicted for eg in examples])
|
all_states = self.init_batch([eg.predicted for eg in examples])
|
||||||
golds = []
|
golds = []
|
||||||
states = []
|
states = []
|
||||||
|
docs = []
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.has_gold(eg) and not state.is_final():
|
if self.has_gold(eg) and not state.is_final():
|
||||||
golds.append(self.init_gold(state, eg))
|
golds.append(self.init_gold(state, eg))
|
||||||
states.append(state)
|
states.append(state)
|
||||||
|
docs.append(eg.x)
|
||||||
n_steps = sum([len(s.queue) for s in states])
|
n_steps = sum([len(s.queue) for s in states])
|
||||||
return states, golds, n_steps
|
return states, golds, docs
|
||||||
|
|
||||||
def _replace_unseen_labels(self, ArcEagerGold gold):
|
def _replace_unseen_labels(self, ArcEagerGold gold):
|
||||||
backoff_label = self.strings["dep"]
|
backoff_label = self.strings["dep"]
|
||||||
|
|
|
@ -120,6 +120,16 @@ cdef class TransitionSystem:
|
||||||
raise ValueError(Errors.E024)
|
raise ValueError(Errors.E024)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
def follow_history(self, doc, history):
|
||||||
|
"""Get the state that results from following a sequence of actions."""
|
||||||
|
cdef int clas
|
||||||
|
cdef StateClass state
|
||||||
|
state = self.init_batch([doc])[0]
|
||||||
|
for clas in history:
|
||||||
|
action = self.c[clas]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
return state
|
||||||
|
|
||||||
def apply_transition(self, StateClass state, name):
|
def apply_transition(self, StateClass state, name):
|
||||||
if not self.is_valid(state, name):
|
if not self.is_valid(state, name):
|
||||||
raise ValueError(Errors.E170.format(name=name))
|
raise ValueError(Errors.E170.format(name=name))
|
||||||
|
|
|
@ -337,21 +337,22 @@ cdef class Parser(TrainablePipe):
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
states, golds, _ = self._init_gold_batch(
|
states, golds, max_moves, state2doc = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=max_moves
|
max_length=max_moves
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, state2doc = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||||
|
|
||||||
|
histories = [[] for example in examples]
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = list(zip(states, golds))
|
states_golds = list(zip(states, golds, state2doc))
|
||||||
n_moves = 0
|
n_moves = 0
|
||||||
while states_golds:
|
while states_golds:
|
||||||
states, golds = zip(*states_golds)
|
states, golds, state2doc = zip(*states_golds)
|
||||||
scores, backprop = model.begin_update(states)
|
scores, backprop = model.begin_update(states)
|
||||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
||||||
# Note that the gradient isn't normalized by the batch size
|
# Note that the gradient isn't normalized by the batch size
|
||||||
|
@ -360,8 +361,13 @@ cdef class Parser(TrainablePipe):
|
||||||
# be getting smaller gradients for states in long sequences.
|
# be getting smaller gradients for states in long sequences.
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
actions = self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
for i, action in enumerate(actions):
|
||||||
|
histories[i].append(action)
|
||||||
|
states_golds = [
|
||||||
|
s for s in zip(states, golds, state2doc)
|
||||||
|
if not s[0].is_final()
|
||||||
|
]
|
||||||
if max_moves >= 1 and n_moves >= max_moves:
|
if max_moves >= 1 and n_moves >= max_moves:
|
||||||
break
|
break
|
||||||
n_moves += 1
|
n_moves += 1
|
||||||
|
@ -370,11 +376,11 @@ cdef class Parser(TrainablePipe):
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
# TODO: Refactor so we don't have to parse twice like this (ugh)
|
states = [
|
||||||
# The issue is that we cut up the gold batch into sub-states, and that
|
self.moves.follow_history(doc, history)
|
||||||
# makes it hard to get the actual predicted transition sequence.
|
for doc, history in zip(docs, histories)
|
||||||
predicted_states = self.predict(docs)
|
]
|
||||||
self.set_annotations(docs, predicted_states)
|
self.set_annotations(docs, self._get_states(docs, states))
|
||||||
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
||||||
# memory ASAP. It seems that Python doesn't necessarily get around to
|
# memory ASAP. It seems that Python doesn't necessarily get around to
|
||||||
# removing these in time if we don't explicitly delete? It's confusing.
|
# removing these in time if we don't explicitly delete? It's confusing.
|
||||||
|
@ -435,13 +441,16 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def update_beam(self, examples, *, beam_width,
|
def update_beam(self, examples, *, beam_width,
|
||||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
|
states, golds, docs = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
[eg.predicted for eg in examples])
|
[eg.predicted for eg in examples])
|
||||||
loss = _beam_utils.update_beam(
|
predicted_histories, loss = _beam_utils.update_beam(
|
||||||
self.moves,
|
self.moves,
|
||||||
states,
|
states,
|
||||||
golds,
|
golds,
|
||||||
|
@ -453,6 +462,12 @@ cdef class Parser(TrainablePipe):
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
states = [
|
||||||
|
self.moves.follow_history(doc, history)
|
||||||
|
for doc, history in zip(docs, predicted_histories)
|
||||||
|
]
|
||||||
|
self.set_annotations(docs, states)
|
||||||
|
return losses
|
||||||
|
|
||||||
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -595,18 +610,24 @@ cdef class Parser(TrainablePipe):
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
to_cut = []
|
to_cut = []
|
||||||
|
# Return a list indicating the position in the batch that each state
|
||||||
|
# refers to. This lets us put together the full list of predicted
|
||||||
|
# histories.
|
||||||
|
state2doc = []
|
||||||
|
doc2i = {eg.x: i for i, eg in enumerate(examples)}
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if self.moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
if len(eg.x) < max_length:
|
if len(eg.x) < max_length:
|
||||||
states.append(state)
|
states.append(state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
|
state2doc.append(doc2i[eg.x])
|
||||||
else:
|
else:
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
state.copy(), gold)
|
state.copy(), gold)
|
||||||
to_cut.append((eg, state, gold, oracle_actions))
|
to_cut.append((eg, state, gold, oracle_actions))
|
||||||
if not to_cut:
|
if not to_cut:
|
||||||
return states, golds, 0
|
return states, golds, 0, state2doc
|
||||||
cdef int clas
|
cdef int clas
|
||||||
for eg, state, gold, oracle_actions in to_cut:
|
for eg, state, gold, oracle_actions in to_cut:
|
||||||
for i in range(0, len(oracle_actions), max_length):
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
|
@ -619,6 +640,7 @@ cdef class Parser(TrainablePipe):
|
||||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
states.append(start_state)
|
states.append(start_state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
|
state2doc.append(doc2i[eg.x])
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
return states, golds, max_length
|
return states, golds, max_length, state2doc
|
||||||
|
|
Loading…
Reference in New Issue
Block a user