mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Try different oracle cuts
This commit is contained in:
parent
1c07820681
commit
cd9194c823
|
@ -281,7 +281,7 @@ cdef class Parser(Pipe):
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
states, golds, _ = self._init_gold_batch(
|
states, golds, max_moves = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=max_moves
|
max_length=max_moves
|
||||||
)
|
)
|
||||||
|
@ -304,10 +304,9 @@ cdef class Parser(Pipe):
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||||
|
n_moves += 1
|
||||||
if max_moves >= 1 and n_moves >= max_moves:
|
if max_moves >= 1 and n_moves >= max_moves:
|
||||||
break
|
break
|
||||||
n_moves += 1
|
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
@ -513,21 +512,19 @@ cdef class Parser(Pipe):
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
states = []
|
|
||||||
golds = []
|
|
||||||
to_cut = []
|
to_cut = []
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if self.moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
if len(eg.x) < max_length:
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
states.append(state)
|
state.copy(), gold)
|
||||||
golds.append(gold)
|
to_cut.append((eg, state, gold, oracle_actions))
|
||||||
else:
|
states = []
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
golds = []
|
||||||
state.copy(), gold)
|
|
||||||
to_cut.append((eg, state, gold, oracle_actions))
|
|
||||||
if not to_cut:
|
if not to_cut:
|
||||||
return states, golds, 0
|
return states, golds, 0
|
||||||
|
lengths = [len(x[-1]) for x in to_cut]
|
||||||
|
max_length = min(min(lengths), max_length)
|
||||||
cdef int clas
|
cdef int clas
|
||||||
for eg, state, gold, oracle_actions in to_cut:
|
for eg, state, gold, oracle_actions in to_cut:
|
||||||
for i in range(0, len(oracle_actions), max_length):
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user