mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 01:13:17 +03:00
Try to fix parser training
This commit is contained in:
parent
3a6b93ae3a
commit
456c881ae3
|
@ -83,6 +83,8 @@ cdef class TransitionSystem:
|
||||||
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
return []
|
return []
|
||||||
|
if not self.has_gold(eg):
|
||||||
|
return []
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
|
|
|
@ -316,8 +316,9 @@ cdef class Parser(TrainablePipe):
|
||||||
validate_examples(examples, "Parser.update")
|
validate_examples(examples, "Parser.update")
|
||||||
for multitask in self._multitasks:
|
for multitask in self._multitasks:
|
||||||
multitask.update(examples, drop=drop, sgd=sgd)
|
multitask.update(examples, drop=drop, sgd=sgd)
|
||||||
|
# We need to take care to act on the whole batch, because we might be
|
||||||
examples = [eg for eg in examples if self.moves.has_gold(eg)]
|
# getting vectors via a listener.
|
||||||
|
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||||
if len(examples) == 0:
|
if len(examples) == 0:
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
@ -347,7 +348,8 @@ cdef class Parser(TrainablePipe):
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
docs = [eg.predicted for eg in examples]
|
||||||
|
model, backprop_tok2vec = self.model.begin_update(docs)
|
||||||
|
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = list(zip(states, golds))
|
states_golds = list(zip(states, golds))
|
||||||
|
@ -371,7 +373,6 @@ cdef class Parser(TrainablePipe):
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
# If we want to set the annotations based on predictions, it's really
|
# If we want to set the annotations based on predictions, it's really
|
||||||
# hard to avoid parsing the data twice :(.
|
# hard to avoid parsing the data twice :(.
|
||||||
# The issue is that we cut up the gold batch into sub-states, and that
|
# The issue is that we cut up the gold batch into sub-states, and that
|
||||||
|
@ -601,7 +602,7 @@ cdef class Parser(TrainablePipe):
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
for state, eg, history in zip(all_states, examples, oracle_histories):
|
for state, eg, history in zip(all_states, examples, oracle_histories):
|
||||||
if state.is_final():
|
if not history:
|
||||||
continue
|
continue
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
if len(history) < max_length:
|
if len(history) < max_length:
|
||||||
|
@ -609,6 +610,8 @@ cdef class Parser(TrainablePipe):
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
continue
|
continue
|
||||||
for i in range(0, len(history), max_length):
|
for i in range(0, len(history), max_length):
|
||||||
|
if state.is_final():
|
||||||
|
break
|
||||||
start_state = state.copy()
|
start_state = state.copy()
|
||||||
for clas in history[i:i+max_length]:
|
for clas in history[i:i+max_length]:
|
||||||
action = self.moves.c[clas]
|
action = self.moves.c[clas]
|
||||||
|
@ -618,6 +621,4 @@ cdef class Parser(TrainablePipe):
|
||||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
states.append(start_state)
|
states.append(start_state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
if state.is_final():
|
|
||||||
break
|
|
||||||
return states, golds, max_length
|
return states, golds, max_length
|
||||||
|
|
Loading…
Reference in New Issue
Block a user