mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-14 22:24:15 +03:00
Try to fix parser training
This commit is contained in:
parent
3a6b93ae3a
commit
456c881ae3
|
@ -83,6 +83,8 @@ cdef class TransitionSystem:
|
|||
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||
if state.is_final():
|
||||
return []
|
||||
if not self.has_gold(eg):
|
||||
return []
|
||||
cdef Pool mem = Pool()
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
assert self.n_moves > 0
|
||||
|
|
|
@ -316,8 +316,9 @@ cdef class Parser(TrainablePipe):
|
|||
validate_examples(examples, "Parser.update")
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(examples, drop=drop, sgd=sgd)
|
||||
|
||||
examples = [eg for eg in examples if self.moves.has_gold(eg)]
|
||||
# We need to take care to act on the whole batch, because we might be
|
||||
# getting vectors via a listener.
|
||||
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||
if len(examples) == 0:
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
@ -347,7 +348,8 @@ cdef class Parser(TrainablePipe):
|
|||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
if not states:
|
||||
return losses
|
||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||
docs = [eg.predicted for eg in examples]
|
||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
||||
|
||||
all_states = list(states)
|
||||
states_golds = list(zip(states, golds))
|
||||
|
@ -371,7 +373,6 @@ cdef class Parser(TrainablePipe):
|
|||
backprop_tok2vec(golds)
|
||||
if sgd not in (None, False):
|
||||
self.finish_update(sgd)
|
||||
docs = [eg.predicted for eg in examples]
|
||||
# If we want to set the annotations based on predictions, it's really
|
||||
# hard to avoid parsing the data twice :(.
|
||||
# The issue is that we cut up the gold batch into sub-states, and that
|
||||
|
@ -601,7 +602,7 @@ cdef class Parser(TrainablePipe):
|
|||
states = []
|
||||
golds = []
|
||||
for state, eg, history in zip(all_states, examples, oracle_histories):
|
||||
if state.is_final():
|
||||
if not history:
|
||||
continue
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
if len(history) < max_length:
|
||||
|
@ -609,6 +610,8 @@ cdef class Parser(TrainablePipe):
|
|||
golds.append(gold)
|
||||
continue
|
||||
for i in range(0, len(history), max_length):
|
||||
if state.is_final():
|
||||
break
|
||||
start_state = state.copy()
|
||||
for clas in history[i:i+max_length]:
|
||||
action = self.moves.c[clas]
|
||||
|
@ -618,6 +621,4 @@ cdef class Parser(TrainablePipe):
|
|||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||
states.append(start_state)
|
||||
golds.append(gold)
|
||||
if state.is_final():
|
||||
break
|
||||
return states, golds, max_length
|
||||
|
|
Loading…
Reference in New Issue
Block a user