Mirror of https://github.com/explosion/spaCy.git (synced 2025-06-27 08:23:12 +03:00)

Commit be155ead9b: Fix set_annotations during parser update
Parent: c631c355d1
@@ -61,6 +61,14 @@ cdef class TransitionSystem:
             offset += len(doc)
         return states
 
+    def follow_history(self, doc, history):
+        cdef int clas
+        cdef StateClass state = StateClass(doc)
+        for clas in history:
+            action = self.c[clas]
+            action.do(state.c, action.label)
+        return state
+
     def get_oracle_sequence(self, Example example, _debug=False):
         states, golds, _ = self.init_gold_batch([example])
         if not states:
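The new follow_history helper simply replays a recorded list of transition IDs against a fresh state. A rough pure-Python sketch of that idea follows; the action table, state layout and function name are invented stand-ins, not spaCy's C-level Transition API.

# Toy illustration of replaying a transition history. TOY_ACTIONS and the
# dict-based state are hypothetical stand-ins for spaCy's internals.
TOY_ACTIONS = {
    0: ("SHIFT", None),
    1: ("LEFT-ARC", "nsubj"),
    2: ("RIGHT-ARC", "dobj"),
}

def follow_history_toy(words, history):
    """Replay a list of action IDs over a fresh state and return the result."""
    state = {"buffer": list(range(len(words))), "stack": [], "arcs": []}
    for clas in history:
        name, label = TOY_ACTIONS[clas]
        if name == "SHIFT" and state["buffer"]:
            state["stack"].append(state["buffer"].pop(0))
        elif name == "LEFT-ARC" and state["stack"] and state["buffer"]:
            # attach the stack top to the front of the buffer
            state["arcs"].append((state["buffer"][0], state["stack"].pop(), label))
        elif name == "RIGHT-ARC" and state["stack"] and state["buffer"]:
            # attach the front of the buffer to the stack top, then consume it
            state["arcs"].append((state["stack"][-1], state["buffer"].pop(0), label))
    return state

# follow_history_toy(["She", "eats", "fish"], [0, 1, 0, 2])
# -> arcs [(1, 0, 'nsubj'), (1, 2, 'dobj')]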
@@ -317,8 +317,8 @@ cdef class Parser(TrainablePipe):
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
 
-        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if n_examples == 0:
+        examples = [eg for eg in examples if self.moves.has_gold(eg)]
+        if len(examples) == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -332,6 +332,7 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
+        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -339,6 +340,7 @@ cdef class Parser(TrainablePipe):
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
+                oracle_histories,
                 max_length=max_moves
             )
         else:
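With these two hunks, the oracle transition sequence for each example is computed once, up front, and threaded into _init_gold_batch; the same lists are reused further down to build the states passed to set_annotations. A hedged sketch of the data's shape, with invented values:

# Invented example of what oracle_histories might hold: one flat list of
# transition IDs per Example, as returned by moves.get_oracle_sequence(eg).
oracle_histories = [
    [0, 0, 3, 1, 0, 2],     # first example: 6 oracle transitions
    [0, 3, 0, 0, 1, 2, 2],  # second example: 7 oracle transitions
]
# _init_gold_batch cuts each history into max_length chunks, and the same
# histories are later replayed by moves.follow_history to annotate the docs,
# so each oracle sequence is only computed once per update.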
@@ -370,11 +372,15 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
+        # If we want to set the annotations based on predictions, it's really
+        # hard to avoid parsing the data twice :(.
         # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        # means there's no one predicted sequence during the update.
+        gold_states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, oracle_histories)
+        ]
+        self.set_annotations(docs, gold_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
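This hunk is the core of the fix: rather than running a second full predict() pass just to have states to annotate with, update() now rebuilds one complete state per doc by replaying its precomputed oracle history. A simplified sketch of that tail end of the method is below; it omits dropout, beam and loss bookkeeping, and the function name is invented, so it is not the verbatim code.

def _update_tail_sketch(self, examples, oracle_histories, sgd):
    # Simplified sketch of the end of Parser.update after this change.
    if sgd not in (None, False):
        self.finish_update(sgd)
    docs = [eg.predicted for eg in examples]
    # No second predict() pass: the batch was cut into sub-states during the
    # update, so there is no single predicted sequence per doc. Replaying the
    # oracle history gives one complete state per doc to annotate from.
    gold_states = [
        self.moves.follow_history(doc, history)
        for doc, history in zip(docs, oracle_histories)
    ]
    self.set_annotations(docs, gold_states)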
@@ -581,7 +587,7 @@ cdef class Parser(TrainablePipe):
                 raise ValueError(Errors.E149) from None
         return self
 
-    def _init_gold_batch(self, examples, max_length):
+    def _init_gold_batch(self, examples, oracle_histories, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -594,24 +600,17 @@ cdef class Parser(TrainablePipe):
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         states = []
         golds = []
-        to_cut = []
-        for state, eg in zip(all_states, examples):
-            if self.moves.has_gold(eg) and not state.is_final():
+        for state, eg, history in zip(all_states, examples, oracle_histories):
+            if state.is_final():
+                continue
             gold = self.moves.init_gold(state, eg)
-            if len(eg.x) < max_length:
+            if len(history) < max_length:
                 states.append(state)
                 golds.append(gold)
-            else:
-                oracle_actions = self.moves.get_oracle_sequence_from_state(
-                    state.copy(), gold)
-                to_cut.append((eg, state, gold, oracle_actions))
-        if not to_cut:
-            return states, golds, 0
-        cdef int clas
-        for eg, state, gold, oracle_actions in to_cut:
-            for i in range(0, len(oracle_actions), max_length):
+                continue
+            for i in range(0, len(history), max_length):
                 start_state = state.copy()
-                for clas in oracle_actions[i:i+max_length]:
+                for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     if state.is_final():
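The "square batch" described in the docstring means that a document whose oracle history exceeds the cap contributes several training states, each responsible for at most max_length transitions. A small, self-contained illustration of that chunking, with invented numbers:

# Invented numbers: one doc with a 23-step oracle history and a cap of 10.
history = list(range(23))
max_length = 10

chunks = [history[i:i + max_length] for i in range(0, len(history), max_length)]
print([len(chunk) for chunk in chunks])  # [10, 10, 3]
# Each chunk yields one training state: the state is copied at the chunk
# boundary (start_state = state.copy() in the diff above) and then advanced
# through at most max_length further transitions.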