mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Improve cutting logic in parser
This commit is contained in:
parent
42e1109def
commit
ca989f4cc4
|
@ -292,10 +292,8 @@ cdef class Parser:
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = zip(states, golds)
|
states_golds = list(zip(states, golds))
|
||||||
for _ in range(max_steps):
|
while states_golds:
|
||||||
if not states_golds:
|
|
||||||
break
|
|
||||||
states, golds = zip(*states_golds)
|
states, golds = zip(*states_golds)
|
||||||
scores, backprop = model.begin_update(states)
|
scores, backprop = model.begin_update(states)
|
||||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
||||||
|
@ -519,21 +517,25 @@ cdef class Parser:
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
|
states = []
|
||||||
|
golds = []
|
||||||
kept = []
|
kept = []
|
||||||
max_length_seen = 0
|
max_length_seen = 0
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if self.moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
if len(eg.x) < max_length:
|
||||||
state.copy(), gold)
|
states.append(state)
|
||||||
kept.append((eg, state, gold, oracle_actions))
|
golds.append(gold)
|
||||||
min_length = min(min_length, len(oracle_actions))
|
else:
|
||||||
max_length_seen = max(max_length, len(oracle_actions))
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
|
state.copy(), gold)
|
||||||
|
kept.append((eg, state, gold, oracle_actions))
|
||||||
|
min_length = min(min_length, len(oracle_actions))
|
||||||
|
max_length_seen = max(max_length, len(oracle_actions))
|
||||||
if not kept:
|
if not kept:
|
||||||
return [], [], 0
|
return states, golds, 0
|
||||||
max_length = max(min_length, min(max_length, max_length_seen))
|
max_length = max(min_length, min(max_length, max_length_seen))
|
||||||
states = []
|
|
||||||
golds = []
|
|
||||||
cdef int clas
|
cdef int clas
|
||||||
max_moves = 0
|
max_moves = 0
|
||||||
for eg, state, gold, oracle_actions in kept:
|
for eg, state, gold, oracle_actions in kept:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user