mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Try different oracle cuts
This commit is contained in:
parent
1c07820681
commit
cd9194c823
|
@ -281,7 +281,7 @@ cdef class Parser(Pipe):
|
|||
# Chop sequences into lengths of this many words, to make the
|
||||
# batch uniform length.
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
states, golds, _ = self._init_gold_batch(
|
||||
states, golds, max_moves = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=max_moves
|
||||
)
|
||||
|
@ -304,10 +304,9 @@ cdef class Parser(Pipe):
|
|||
# Follow the predicted action
|
||||
self.transition_states(states, scores)
|
||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||
n_moves += 1
|
||||
if max_moves >= 1 and n_moves >= max_moves:
|
||||
break
|
||||
n_moves += 1
|
||||
|
||||
backprop_tok2vec(golds)
|
||||
if sgd not in (None, False):
|
||||
self.model.finish_update(sgd)
|
||||
|
@ -513,21 +512,19 @@ cdef class Parser(Pipe):
|
|||
StateClass state
|
||||
Transition action
|
||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||
states = []
|
||||
golds = []
|
||||
to_cut = []
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.moves.has_gold(eg) and not state.is_final():
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
if len(eg.x) < max_length:
|
||||
states.append(state)
|
||||
golds.append(gold)
|
||||
else:
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
to_cut.append((eg, state, gold, oracle_actions))
|
||||
states = []
|
||||
golds = []
|
||||
if not to_cut:
|
||||
return states, golds, 0
|
||||
lengths = [len(x[-1]) for x in to_cut]
|
||||
max_length = min(min(lengths), max_length)
|
||||
cdef int clas
|
||||
for eg, state, gold, oracle_actions in to_cut:
|
||||
for i in range(0, len(oracle_actions), max_length):
|
||||
|
|
Loading…
Reference in New Issue
Block a user