mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Improve cutting logic in parser
This commit is contained in:
parent
42e1109def
commit
ca989f4cc4
|
@ -292,10 +292,8 @@ cdef class Parser:
|
|||
if not states:
|
||||
return losses
|
||||
all_states = list(states)
|
||||
states_golds = zip(states, golds)
|
||||
for _ in range(max_steps):
|
||||
if not states_golds:
|
||||
break
|
||||
states_golds = list(zip(states, golds))
|
||||
while states_golds:
|
||||
states, golds = zip(*states_golds)
|
||||
scores, backprop = model.begin_update(states)
|
||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
||||
|
@ -519,21 +517,25 @@ cdef class Parser:
|
|||
StateClass state
|
||||
Transition action
|
||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||
states = []
|
||||
golds = []
|
||||
kept = []
|
||||
max_length_seen = 0
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.moves.has_gold(eg) and not state.is_final():
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
kept.append((eg, state, gold, oracle_actions))
|
||||
min_length = min(min_length, len(oracle_actions))
|
||||
max_length_seen = max(max_length, len(oracle_actions))
|
||||
if len(eg.x) < max_length:
|
||||
states.append(state)
|
||||
golds.append(gold)
|
||||
else:
|
||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
kept.append((eg, state, gold, oracle_actions))
|
||||
min_length = min(min_length, len(oracle_actions))
|
||||
max_length_seen = max(max_length, len(oracle_actions))
|
||||
if not kept:
|
||||
return [], [], 0
|
||||
return states, golds, 0
|
||||
max_length = max(min_length, min(max_length, max_length_seen))
|
||||
states = []
|
||||
golds = []
|
||||
cdef int clas
|
||||
max_moves = 0
|
||||
for eg, state, gold, oracle_actions in kept:
|
||||
|
|
Loading…
Reference in New Issue
Block a user