mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Refactor arc_eager to use new TransitionSystem base class. Need to fix oracle
This commit is contained in:
parent
b063001596
commit
8eadb984cb
|
@ -4,25 +4,8 @@ from thinc.typedefs cimport weight_t
|
||||||
|
|
||||||
|
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
|
from .transition_system cimport TransitionSystem, Transition
|
||||||
|
|
||||||
|
|
||||||
cdef struct Transition:
|
cdef class ArcEager(TransitionSystem):
|
||||||
int clas
|
pass
|
||||||
int move
|
|
||||||
int label
|
|
||||||
int cost
|
|
||||||
weight_t score
|
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
|
||||||
cdef Pool mem
|
|
||||||
cdef readonly int n_moves
|
|
||||||
cdef dict label_ids
|
|
||||||
|
|
||||||
cdef const Transition* _moves
|
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *
|
|
||||||
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
|
|
||||||
const State* s,
|
|
||||||
const int* gold_heads, const int* gold_labels) except *
|
|
||||||
cdef int transition(self, State *s, const Transition* t) except -1
|
|
||||||
|
|
|
@ -8,11 +8,17 @@ from ._state cimport head_in_stack, children_in_stack
|
||||||
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
|
|
||||||
|
from .transition_system cimport do_func_t, get_cost_func_t
|
||||||
|
from .conll cimport GoldParse
|
||||||
|
|
||||||
|
|
||||||
DEF NON_MONOTONIC = True
|
DEF NON_MONOTONIC = True
|
||||||
DEF USE_BREAK = True
|
DEF USE_BREAK = True
|
||||||
|
|
||||||
|
cdef weight_t MIN_SCORE = -90000
|
||||||
|
|
||||||
|
# Break transition from here
|
||||||
|
# http://www.aclweb.org/anthology/P13-1074
|
||||||
cdef enum:
|
cdef enum:
|
||||||
SHIFT
|
SHIFT
|
||||||
REDUCE
|
REDUCE
|
||||||
|
@ -21,8 +27,150 @@ cdef enum:
|
||||||
BREAK
|
BREAK
|
||||||
N_MOVES
|
N_MOVES
|
||||||
|
|
||||||
# Break transition from here
|
|
||||||
# http://www.aclweb.org/anthology/P13-1074
|
cdef do_func_t[N_MOVES] do_funcs
|
||||||
|
cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
||||||
|
|
||||||
|
|
||||||
|
cdef class ArcEager(TransitionSystem):
|
||||||
|
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||||
|
return Transition(
|
||||||
|
score=0,
|
||||||
|
clas=i,
|
||||||
|
move=move,
|
||||||
|
label=label,
|
||||||
|
do=do_funcs[move],
|
||||||
|
get_cost=get_cost_funcs[move])
|
||||||
|
|
||||||
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
|
cdef bint[N_MOVES] is_valid
|
||||||
|
is_valid[SHIFT] = _can_shift(s)
|
||||||
|
is_valid[REDUCE] = _can_reduce(s)
|
||||||
|
is_valid[LEFT] = _can_left(s)
|
||||||
|
is_valid[RIGHT] = _can_right(s)
|
||||||
|
is_valid[BREAK] = _can_break(s)
|
||||||
|
cdef Transition best
|
||||||
|
cdef weight_t score = MIN_SCORE
|
||||||
|
cdef int i
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
if scores[i] > score and is_valid[self.c[i].move]:
|
||||||
|
best = self.c[i]
|
||||||
|
score = scores[i]
|
||||||
|
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||||
|
# actions
|
||||||
|
if best.move == SHIFT:
|
||||||
|
score = MIN_SCORE
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
if self.c[i].move == RIGHT and scores[i] > score:
|
||||||
|
best.label = self.c[i].label
|
||||||
|
score = scores[i]
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _do_shift(const Transition* self, State* state) except -1:
|
||||||
|
# Set the dep label, in case we need it after we reduce
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
get_s0(state).dep = self.label
|
||||||
|
push_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _do_left(const Transition* self, State* state) except -1:
|
||||||
|
add_dep(state, state.i, state.stack[0], self.label)
|
||||||
|
pop_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _do_right(const Transition* self, State* state) except -1:
|
||||||
|
add_dep(state, state.stack[0], state.i, self.label)
|
||||||
|
push_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _do_reduce(const Transition* self, State* state) except -1:
|
||||||
|
# TODO: Huh? Is this some weirdness from the non-monotonic?
|
||||||
|
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||||
|
pop_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _do_break(const Transition* self, State* state) except -1:
|
||||||
|
state.sent[state.i-1].sent_end = True
|
||||||
|
while state.stack_len != 0:
|
||||||
|
if get_s0(state).head == 0:
|
||||||
|
get_s0(state).dep = 0
|
||||||
|
state.stack -= 1
|
||||||
|
state.stack_len -= 1
|
||||||
|
if not at_eol(state):
|
||||||
|
push_stack(state)
|
||||||
|
|
||||||
|
|
||||||
|
do_funcs[SHIFT] = _do_shift
|
||||||
|
do_funcs[REDUCE] = _do_reduce
|
||||||
|
do_funcs[LEFT] = _do_left
|
||||||
|
do_funcs[RIGHT] = _do_right
|
||||||
|
do_funcs[BREAK] = _do_break
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
assert not at_eol(s)
|
||||||
|
cost = 0
|
||||||
|
cost += head_in_stack(s, s.i, gold.heads)
|
||||||
|
cost += children_in_stack(s, s.i, gold.heads)
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
cost += gold[s.stack[0]] == s.i
|
||||||
|
# If we can break, and there's no cost to doing so, we should
|
||||||
|
if _can_break(s) and _break_cost(self, s, gold) == 0:
|
||||||
|
cost += 1
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
assert s.stack_len >= 1
|
||||||
|
cost = 0
|
||||||
|
if gold[s.i] == s.stack[0]:
|
||||||
|
return cost
|
||||||
|
cost += head_in_buffer(s, s.i, gold.heads)
|
||||||
|
cost += children_in_stack(s, s.i, gold.heads)
|
||||||
|
cost += head_in_stack(s, s.i, gold.heads)
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
cost += gold[s.stack[0]] == s.i
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
assert s.stack_len >= 1
|
||||||
|
cost = 0
|
||||||
|
if gold[s.stack[0]] == s.i:
|
||||||
|
return cost
|
||||||
|
|
||||||
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
if NON_MONOTONIC and s.stack_len >= 2:
|
||||||
|
cost += gold[s.stack[0]] == s.stack[-1]
|
||||||
|
cost += gold[s.stack[0]] == s.stack[0]
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
cdef int cost = 0
|
||||||
|
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
|
# When we break, we Reduce all of the words on the stack.
|
||||||
|
cdef int cost = 0
|
||||||
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
|
for i in range(s.i, s.sent_len):
|
||||||
|
cost += children_in_stack(s, i, gold.heads)
|
||||||
|
cost += head_in_stack(s, i, gold.heads)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
get_cost_funcs[SHIFT] = _shift_cost
|
||||||
|
get_cost_funcs[REDUCE] = _reduce_cost
|
||||||
|
get_cost_funcs[LEFT] = _left_cost
|
||||||
|
get_cost_funcs[RIGHT] = _right_cost
|
||||||
|
get_cost_funcs[BREAK] = _break_cost
|
||||||
|
|
||||||
|
|
||||||
cdef inline bint _can_shift(const State* s) nogil:
|
cdef inline bint _can_shift(const State* s) nogil:
|
||||||
|
@ -63,224 +211,3 @@ cdef inline bint _can_break(const State* s) nogil:
|
||||||
else:
|
else:
|
||||||
seen_headless = True
|
seen_headless = True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
cdef int _shift_cost(const State* s, const int* gold) except -1:
|
|
||||||
assert not at_eol(s)
|
|
||||||
cost = 0
|
|
||||||
cost += head_in_stack(s, s.i, gold)
|
|
||||||
cost += children_in_stack(s, s.i, gold)
|
|
||||||
if NON_MONOTONIC:
|
|
||||||
cost += gold[s.stack[0]] == s.i
|
|
||||||
# If we can break, and there's no cost to doing so, we should
|
|
||||||
if _can_break(s) and _break_cost(s, gold) == 0:
|
|
||||||
cost += 1
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _right_cost(const State* s, const int* gold) except -1:
|
|
||||||
assert s.stack_len >= 1
|
|
||||||
cost = 0
|
|
||||||
if gold[s.i] == s.stack[0]:
|
|
||||||
return cost
|
|
||||||
cost += head_in_buffer(s, s.i, gold)
|
|
||||||
cost += children_in_stack(s, s.i, gold)
|
|
||||||
cost += head_in_stack(s, s.i, gold)
|
|
||||||
if NON_MONOTONIC:
|
|
||||||
cost += gold[s.stack[0]] == s.i
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _left_cost(const State* s, const int* gold) except -1:
|
|
||||||
assert s.stack_len >= 1
|
|
||||||
cost = 0
|
|
||||||
if gold[s.stack[0]] == s.i:
|
|
||||||
return cost
|
|
||||||
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold)
|
|
||||||
cost += children_in_buffer(s, s.stack[0], gold)
|
|
||||||
if NON_MONOTONIC and s.stack_len >= 2:
|
|
||||||
cost += gold[s.stack[0]] == s.stack[-1]
|
|
||||||
cost += gold[s.stack[0]] == s.stack[0]
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _reduce_cost(const State* s, const int* gold) except -1:
|
|
||||||
cdef int cost = 0
|
|
||||||
cost += children_in_buffer(s, s.stack[0], gold)
|
|
||||||
if NON_MONOTONIC:
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold)
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _break_cost(const State* s, const int* gold) except -1:
|
|
||||||
# When we break, we Reduce all of the words on the stack.
|
|
||||||
cdef int cost = 0
|
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
|
||||||
for i in range(s.i, s.sent_len):
|
|
||||||
cost += children_in_stack(s, i, gold)
|
|
||||||
cost += head_in_stack(s, i, gold)
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcEager(TransitionSystem):
|
|
||||||
def __init__(self, list left_labels, list right_labels):
|
|
||||||
self.mem = Pool()
|
|
||||||
left_labels.sort()
|
|
||||||
right_labels.sort()
|
|
||||||
if 'ROOT' in right_labels:
|
|
||||||
right_labels.pop(right_labels.index('ROOT'))
|
|
||||||
if 'ROOT' in left_labels:
|
|
||||||
left_labels.pop(left_labels.index('ROOT'))
|
|
||||||
self.n_moves = 3 + len(left_labels) + len(right_labels)
|
|
||||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
|
||||||
cdef int i = 0
|
|
||||||
moves[i].move = SHIFT
|
|
||||||
moves[i].label = 0
|
|
||||||
moves[i].clas = i
|
|
||||||
i += 1
|
|
||||||
moves[i].move = REDUCE
|
|
||||||
moves[i].label = 0
|
|
||||||
moves[i].clas = i
|
|
||||||
i += 1
|
|
||||||
self.label_ids = {'ROOT': 0}
|
|
||||||
cdef int label_id
|
|
||||||
for label_str in left_labels:
|
|
||||||
label_str = unicode(label_str)
|
|
||||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
|
||||||
moves[i].move = LEFT
|
|
||||||
moves[i].label = label_id
|
|
||||||
moves[i].clas = i
|
|
||||||
i += 1
|
|
||||||
for label_str in right_labels:
|
|
||||||
label_str = unicode(label_str)
|
|
||||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
|
||||||
moves[i].move = RIGHT
|
|
||||||
moves[i].label = label_id
|
|
||||||
moves[i].clas = i
|
|
||||||
i += 1
|
|
||||||
moves[i].move = BREAK
|
|
||||||
moves[i].label = 0
|
|
||||||
moves[i].clas = i
|
|
||||||
i += 1
|
|
||||||
self.c = moves
|
|
||||||
|
|
||||||
cdef int transition(self, State *s, const Transition* t) except -1:
|
|
||||||
if t.move == SHIFT:
|
|
||||||
# Set the dep label, in case we need it after we reduce
|
|
||||||
if NON_MONOTONIC:
|
|
||||||
s.sent[s.i].dep = t.label
|
|
||||||
push_stack(s)
|
|
||||||
elif t.move == LEFT:
|
|
||||||
add_dep(s, s.i, s.stack[0], t.label)
|
|
||||||
pop_stack(s)
|
|
||||||
elif t.move == RIGHT:
|
|
||||||
add_dep(s, s.stack[0], s.i, t.label)
|
|
||||||
push_stack(s)
|
|
||||||
elif t.move == REDUCE:
|
|
||||||
# TODO: Huh? Is this some weirdness from the non-monotonic?
|
|
||||||
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
|
|
||||||
pop_stack(s)
|
|
||||||
elif t.move == BREAK:
|
|
||||||
s.sent[s.i-1].sent_end = True
|
|
||||||
while s.stack_len != 0:
|
|
||||||
if get_s0(s).head == 0:
|
|
||||||
get_s0(s).dep = 0
|
|
||||||
s.stack -= 1
|
|
||||||
s.stack_len -= 1
|
|
||||||
|
|
||||||
if not at_eol(s):
|
|
||||||
push_stack(s)
|
|
||||||
else:
|
|
||||||
raise Exception(t.move)
|
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
|
||||||
cdef bint[N_MOVES] valid
|
|
||||||
valid[SHIFT] = _can_shift(s)
|
|
||||||
valid[LEFT] = _can_left(s)
|
|
||||||
valid[RIGHT] = _can_right(s)
|
|
||||||
valid[REDUCE] = _can_reduce(s)
|
|
||||||
valid[BREAK] = _can_break(s)
|
|
||||||
|
|
||||||
cdef int best = -1
|
|
||||||
cdef weight_t score = 0
|
|
||||||
cdef weight_t best_r_score = -9000
|
|
||||||
cdef int best_r_label = -1
|
|
||||||
cdef int i
|
|
||||||
for i in range(self.n_moves):
|
|
||||||
if valid[self._moves[i].move] and (best == -1 or scores[i] > score):
|
|
||||||
best = i
|
|
||||||
score = scores[i]
|
|
||||||
if self._moves[i].move == RIGHT and scores[i] > best_r_score:
|
|
||||||
best_r_label = self._moves[i].label
|
|
||||||
assert best >= 0
|
|
||||||
cdef Transition t = self._moves[best]
|
|
||||||
t.score = score
|
|
||||||
if t.move == SHIFT:
|
|
||||||
t.label = best_r_label
|
|
||||||
return t
|
|
||||||
|
|
||||||
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
|
|
||||||
const State* s,
|
|
||||||
const int* gold_heads, const int* gold_labels) except *:
|
|
||||||
# If we can create a gold dependency, only one action can be correct
|
|
||||||
cdef int[N_MOVES] unl_costs
|
|
||||||
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
|
|
||||||
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
|
||||||
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
|
||||||
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
|
||||||
unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1
|
|
||||||
|
|
||||||
guess.cost = unl_costs[guess.move]
|
|
||||||
cdef Transition t
|
|
||||||
cdef int target_label
|
|
||||||
cdef int i
|
|
||||||
if gold_heads[s.stack[0]] == s.i:
|
|
||||||
target_label = gold_labels[s.stack[0]]
|
|
||||||
if guess.move == LEFT:
|
|
||||||
guess.cost += guess.label != target_label
|
|
||||||
for i in range(self.n_moves):
|
|
||||||
t = self._moves[i]
|
|
||||||
if t.move == LEFT and t.label == target_label:
|
|
||||||
return t
|
|
||||||
elif gold_heads[s.i] == s.stack[0]:
|
|
||||||
target_label = gold_labels[s.i]
|
|
||||||
if guess.move == RIGHT:
|
|
||||||
if unl_costs[guess.move] != 0:
|
|
||||||
guess.cost += guess.label != target_label
|
|
||||||
for i in range(self.n_moves):
|
|
||||||
t = self._moves[i]
|
|
||||||
if t.label == target_label and unl_costs[t.move] == 0:
|
|
||||||
return t
|
|
||||||
|
|
||||||
cdef int best = -1
|
|
||||||
cdef weight_t score = -9000
|
|
||||||
for i in range(self.n_moves):
|
|
||||||
t = self._moves[i]
|
|
||||||
if unl_costs[t.move] == 0 and (best == -1 or scores[i] > score):
|
|
||||||
best = i
|
|
||||||
score = scores[i]
|
|
||||||
t = self._moves[best]
|
|
||||||
t.score = score
|
|
||||||
if best < 0:
|
|
||||||
msg = ("No gold move found for configuration.\n"
|
|
||||||
"Is the gold-standard parse a projective tree?\n"
|
|
||||||
"S unl cost: %d\n"
|
|
||||||
"D unl cost: %d\n"
|
|
||||||
"L unl cost: %d\n"
|
|
||||||
"R unl cost: %d\n"
|
|
||||||
"S0, S0 gold: %d, %d\n"
|
|
||||||
"N0, N0 gold: %d, %d\n"
|
|
||||||
|
|
||||||
)
|
|
||||||
fields = [unl_costs[SHIFT], unl_costs[REDUCE], unl_costs[LEFT],
|
|
||||||
unl_costs[RIGHT],
|
|
||||||
s.stack[0], gold_heads[s.stack[0]],
|
|
||||||
s.i, gold_heads[s.i]]
|
|
||||||
|
|
||||||
raise OracleError(msg % tuple(fields))
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
class OracleError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from thinc.typedefs cimport weight_t
|
||||||
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
|
from .conll cimport GoldParse
|
||||||
|
|
||||||
|
|
||||||
cdef struct Transition:
|
cdef struct Transition:
|
||||||
|
@ -12,19 +12,13 @@ cdef struct Transition:
|
||||||
int label
|
int label
|
||||||
|
|
||||||
weight_t score
|
weight_t score
|
||||||
int cost
|
|
||||||
|
|
||||||
int (*get_cost)(const Transition* self, const State* state, const TokenC* gold) except -1
|
|
||||||
|
|
||||||
int (*is_valid)(const Transition* self, const State* state) except -1
|
|
||||||
|
|
||||||
|
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
|
||||||
int (*do)(const Transition* self, State* state) except -1
|
int (*do)(const Transition* self, State* state) except -1
|
||||||
|
|
||||||
|
|
||||||
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
||||||
const TokenC* gold) except -1
|
GoldParse gold) except -1
|
||||||
|
|
||||||
ctypedef int (*is_valid_func_t)(const Transition* self, const State* state) except -1
|
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
||||||
|
|
||||||
|
@ -39,4 +33,4 @@ cdef class TransitionSystem:
|
||||||
cdef const Transition best_valid(self, const weight_t*, const State*) except *
|
cdef const Transition best_valid(self, const weight_t*, const State*) except *
|
||||||
|
|
||||||
cdef const Transition best_gold(self, const weight_t*, const State*,
|
cdef const Transition best_gold(self, const weight_t*, const State*,
|
||||||
const TokenC*) except *
|
GoldParse gold) except *
|
||||||
|
|
|
@ -25,22 +25,7 @@ cdef class TransitionSystem:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef Transition best
|
raise NotImplementedError
|
||||||
cdef weight_t score = MIN_SCORE
|
|
||||||
cdef int i
|
|
||||||
for i in range(self.n_moves):
|
|
||||||
if scores[i] > score and self.c[i].is_valid(&self.c[i], s):
|
|
||||||
best = self.c[i]
|
|
||||||
score = scores[i]
|
|
||||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
|
||||||
# actions
|
|
||||||
#if best.move == SHIFT:
|
|
||||||
# score = MIN_SCORE
|
|
||||||
# for i in range(self.n_moves):
|
|
||||||
# if self.c[i].move == RIGHT and scores[i] > score:
|
|
||||||
# best.label = self.c[i].label
|
|
||||||
# score = scores[i]
|
|
||||||
return best
|
|
||||||
|
|
||||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||||
const TokenC* gold) except *:
|
const TokenC* gold) except *:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user