mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Refactor arc_eager to use new TransitionSystem base class. Need to fix oracle
This commit is contained in:
parent
b063001596
commit
8eadb984cb
|
@ -4,25 +4,8 @@ from thinc.typedefs cimport weight_t
|
|||
|
||||
|
||||
from ._state cimport State
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
|
||||
|
||||
cdef struct Transition:
|
||||
int clas
|
||||
int move
|
||||
int label
|
||||
int cost
|
||||
weight_t score
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
cdef Pool mem
|
||||
cdef readonly int n_moves
|
||||
cdef dict label_ids
|
||||
|
||||
cdef const Transition* _moves
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *
|
||||
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
|
||||
const State* s,
|
||||
const int* gold_heads, const int* gold_labels) except *
|
||||
cdef int transition(self, State *s, const Transition* t) except -1
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
pass
|
||||
|
|
|
@ -8,11 +8,17 @@ from ._state cimport head_in_stack, children_in_stack
|
|||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
from .transition_system cimport do_func_t, get_cost_func_t
|
||||
from .conll cimport GoldParse
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
# Break transition from here
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
cdef enum:
|
||||
SHIFT
|
||||
REDUCE
|
||||
|
@ -21,8 +27,150 @@ cdef enum:
|
|||
BREAK
|
||||
N_MOVES
|
||||
|
||||
# Break transition from here
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
|
||||
cdef do_func_t[N_MOVES] do_funcs
|
||||
cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||
return Transition(
|
||||
score=0,
|
||||
clas=i,
|
||||
move=move,
|
||||
label=label,
|
||||
do=do_funcs[move],
|
||||
get_cost=get_cost_funcs[move])
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = _can_shift(s)
|
||||
is_valid[REDUCE] = _can_reduce(s)
|
||||
is_valid[LEFT] = _can_left(s)
|
||||
is_valid[RIGHT] = _can_right(s)
|
||||
is_valid[BREAK] = _can_break(s)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if scores[i] > score and is_valid[self.c[i].move]:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||
# actions
|
||||
if best.move == SHIFT:
|
||||
score = MIN_SCORE
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].move == RIGHT and scores[i] > score:
|
||||
best.label = self.c[i].label
|
||||
score = scores[i]
|
||||
return best
|
||||
|
||||
|
||||
cdef int _do_shift(const Transition* self, State* state) except -1:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
get_s0(state).dep = self.label
|
||||
push_stack(state)
|
||||
|
||||
|
||||
cdef int _do_left(const Transition* self, State* state) except -1:
|
||||
add_dep(state, state.i, state.stack[0], self.label)
|
||||
pop_stack(state)
|
||||
|
||||
|
||||
cdef int _do_right(const Transition* self, State* state) except -1:
|
||||
add_dep(state, state.stack[0], state.i, self.label)
|
||||
push_stack(state)
|
||||
|
||||
|
||||
cdef int _do_reduce(const Transition* self, State* state) except -1:
|
||||
# TODO: Huh? Is this some weirdness from the non-monotonic?
|
||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||
pop_stack(state)
|
||||
|
||||
|
||||
cdef int _do_break(const Transition* self, State* state) except -1:
|
||||
state.sent[state.i-1].sent_end = True
|
||||
while state.stack_len != 0:
|
||||
if get_s0(state).head == 0:
|
||||
get_s0(state).dep = 0
|
||||
state.stack -= 1
|
||||
state.stack_len -= 1
|
||||
if not at_eol(state):
|
||||
push_stack(state)
|
||||
|
||||
|
||||
do_funcs[SHIFT] = _do_shift
|
||||
do_funcs[REDUCE] = _do_reduce
|
||||
do_funcs[LEFT] = _do_left
|
||||
do_funcs[RIGHT] = _do_right
|
||||
do_funcs[BREAK] = _do_break
|
||||
|
||||
|
||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert not at_eol(s)
|
||||
cost = 0
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
if NON_MONOTONIC:
|
||||
cost += gold[s.stack[0]] == s.i
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if _can_break(s) and _break_cost(self, s, gold) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.i] == s.stack[0]:
|
||||
return cost
|
||||
cost += head_in_buffer(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
if NON_MONOTONIC:
|
||||
cost += gold[s.stack[0]] == s.i
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.stack[0]] == s.i:
|
||||
return cost
|
||||
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC and s.stack_len >= 2:
|
||||
cost += gold[s.stack[0]] == s.stack[-1]
|
||||
cost += gold[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC:
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
for i in range(s.i, s.sent_len):
|
||||
cost += children_in_stack(s, i, gold.heads)
|
||||
cost += head_in_stack(s, i, gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
get_cost_funcs[SHIFT] = _shift_cost
|
||||
get_cost_funcs[REDUCE] = _reduce_cost
|
||||
get_cost_funcs[LEFT] = _left_cost
|
||||
get_cost_funcs[RIGHT] = _right_cost
|
||||
get_cost_funcs[BREAK] = _break_cost
|
||||
|
||||
|
||||
cdef inline bint _can_shift(const State* s) nogil:
|
||||
|
@ -63,224 +211,3 @@ cdef inline bint _can_break(const State* s) nogil:
|
|||
else:
|
||||
seen_headless = True
|
||||
return True
|
||||
|
||||
|
||||
cdef int _shift_cost(const State* s, const int* gold) except -1:
|
||||
assert not at_eol(s)
|
||||
cost = 0
|
||||
cost += head_in_stack(s, s.i, gold)
|
||||
cost += children_in_stack(s, s.i, gold)
|
||||
if NON_MONOTONIC:
|
||||
cost += gold[s.stack[0]] == s.i
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if _can_break(s) and _break_cost(s, gold) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _right_cost(const State* s, const int* gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.i] == s.stack[0]:
|
||||
return cost
|
||||
cost += head_in_buffer(s, s.i, gold)
|
||||
cost += children_in_stack(s, s.i, gold)
|
||||
cost += head_in_stack(s, s.i, gold)
|
||||
if NON_MONOTONIC:
|
||||
cost += gold[s.stack[0]] == s.i
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _left_cost(const State* s, const int* gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.stack[0]] == s.i:
|
||||
return cost
|
||||
|
||||
cost += head_in_buffer(s, s.stack[0], gold)
|
||||
cost += children_in_buffer(s, s.stack[0], gold)
|
||||
if NON_MONOTONIC and s.stack_len >= 2:
|
||||
cost += gold[s.stack[0]] == s.stack[-1]
|
||||
cost += gold[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _reduce_cost(const State* s, const int* gold) except -1:
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(s, s.stack[0], gold)
|
||||
if NON_MONOTONIC:
|
||||
cost += head_in_buffer(s, s.stack[0], gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _break_cost(const State* s, const int* gold) except -1:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
for i in range(s.i, s.sent_len):
|
||||
cost += children_in_stack(s, i, gold)
|
||||
cost += head_in_stack(s, i, gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
def __init__(self, list left_labels, list right_labels):
|
||||
self.mem = Pool()
|
||||
left_labels.sort()
|
||||
right_labels.sort()
|
||||
if 'ROOT' in right_labels:
|
||||
right_labels.pop(right_labels.index('ROOT'))
|
||||
if 'ROOT' in left_labels:
|
||||
left_labels.pop(left_labels.index('ROOT'))
|
||||
self.n_moves = 3 + len(left_labels) + len(right_labels)
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
moves[i].move = SHIFT
|
||||
moves[i].label = 0
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
moves[i].move = REDUCE
|
||||
moves[i].label = 0
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
self.label_ids = {'ROOT': 0}
|
||||
cdef int label_id
|
||||
for label_str in left_labels:
|
||||
label_str = unicode(label_str)
|
||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
||||
moves[i].move = LEFT
|
||||
moves[i].label = label_id
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
for label_str in right_labels:
|
||||
label_str = unicode(label_str)
|
||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
||||
moves[i].move = RIGHT
|
||||
moves[i].label = label_id
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
moves[i].move = BREAK
|
||||
moves[i].label = 0
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
self.c = moves
|
||||
|
||||
cdef int transition(self, State *s, const Transition* t) except -1:
|
||||
if t.move == SHIFT:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
s.sent[s.i].dep = t.label
|
||||
push_stack(s)
|
||||
elif t.move == LEFT:
|
||||
add_dep(s, s.i, s.stack[0], t.label)
|
||||
pop_stack(s)
|
||||
elif t.move == RIGHT:
|
||||
add_dep(s, s.stack[0], s.i, t.label)
|
||||
push_stack(s)
|
||||
elif t.move == REDUCE:
|
||||
# TODO: Huh? Is this some weirdness from the non-monotonic?
|
||||
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
|
||||
pop_stack(s)
|
||||
elif t.move == BREAK:
|
||||
s.sent[s.i-1].sent_end = True
|
||||
while s.stack_len != 0:
|
||||
if get_s0(s).head == 0:
|
||||
get_s0(s).dep = 0
|
||||
s.stack -= 1
|
||||
s.stack_len -= 1
|
||||
|
||||
if not at_eol(s):
|
||||
push_stack(s)
|
||||
else:
|
||||
raise Exception(t.move)
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] valid
|
||||
valid[SHIFT] = _can_shift(s)
|
||||
valid[LEFT] = _can_left(s)
|
||||
valid[RIGHT] = _can_right(s)
|
||||
valid[REDUCE] = _can_reduce(s)
|
||||
valid[BREAK] = _can_break(s)
|
||||
|
||||
cdef int best = -1
|
||||
cdef weight_t score = 0
|
||||
cdef weight_t best_r_score = -9000
|
||||
cdef int best_r_label = -1
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if valid[self._moves[i].move] and (best == -1 or scores[i] > score):
|
||||
best = i
|
||||
score = scores[i]
|
||||
if self._moves[i].move == RIGHT and scores[i] > best_r_score:
|
||||
best_r_label = self._moves[i].label
|
||||
assert best >= 0
|
||||
cdef Transition t = self._moves[best]
|
||||
t.score = score
|
||||
if t.move == SHIFT:
|
||||
t.label = best_r_label
|
||||
return t
|
||||
|
||||
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
|
||||
const State* s,
|
||||
const int* gold_heads, const int* gold_labels) except *:
|
||||
# If we can create a gold dependency, only one action can be correct
|
||||
cdef int[N_MOVES] unl_costs
|
||||
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
|
||||
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
||||
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
||||
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
||||
unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1
|
||||
|
||||
guess.cost = unl_costs[guess.move]
|
||||
cdef Transition t
|
||||
cdef int target_label
|
||||
cdef int i
|
||||
if gold_heads[s.stack[0]] == s.i:
|
||||
target_label = gold_labels[s.stack[0]]
|
||||
if guess.move == LEFT:
|
||||
guess.cost += guess.label != target_label
|
||||
for i in range(self.n_moves):
|
||||
t = self._moves[i]
|
||||
if t.move == LEFT and t.label == target_label:
|
||||
return t
|
||||
elif gold_heads[s.i] == s.stack[0]:
|
||||
target_label = gold_labels[s.i]
|
||||
if guess.move == RIGHT:
|
||||
if unl_costs[guess.move] != 0:
|
||||
guess.cost += guess.label != target_label
|
||||
for i in range(self.n_moves):
|
||||
t = self._moves[i]
|
||||
if t.label == target_label and unl_costs[t.move] == 0:
|
||||
return t
|
||||
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -9000
|
||||
for i in range(self.n_moves):
|
||||
t = self._moves[i]
|
||||
if unl_costs[t.move] == 0 and (best == -1 or scores[i] > score):
|
||||
best = i
|
||||
score = scores[i]
|
||||
t = self._moves[best]
|
||||
t.score = score
|
||||
if best < 0:
|
||||
msg = ("No gold move found for configuration.\n"
|
||||
"Is the gold-standard parse a projective tree?\n"
|
||||
"S unl cost: %d\n"
|
||||
"D unl cost: %d\n"
|
||||
"L unl cost: %d\n"
|
||||
"R unl cost: %d\n"
|
||||
"S0, S0 gold: %d, %d\n"
|
||||
"N0, N0 gold: %d, %d\n"
|
||||
|
||||
)
|
||||
fields = [unl_costs[SHIFT], unl_costs[REDUCE], unl_costs[LEFT],
|
||||
unl_costs[RIGHT],
|
||||
s.stack[0], gold_heads[s.stack[0]],
|
||||
s.i, gold_heads[s.i]]
|
||||
|
||||
raise OracleError(msg % tuple(fields))
|
||||
return t
|
||||
|
||||
|
||||
class OracleError(Exception):
|
||||
pass
|
||||
|
|
|
@ -3,7 +3,7 @@ from thinc.typedefs cimport weight_t
|
|||
|
||||
from ..structs cimport TokenC
|
||||
from ._state cimport State
|
||||
|
||||
from .conll cimport GoldParse
|
||||
|
||||
|
||||
cdef struct Transition:
|
||||
|
@ -12,19 +12,13 @@ cdef struct Transition:
|
|||
int label
|
||||
|
||||
weight_t score
|
||||
int cost
|
||||
|
||||
int (*get_cost)(const Transition* self, const State* state, const TokenC* gold) except -1
|
||||
|
||||
int (*is_valid)(const Transition* self, const State* state) except -1
|
||||
|
||||
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
|
||||
int (*do)(const Transition* self, State* state) except -1
|
||||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
||||
const TokenC* gold) except -1
|
||||
|
||||
ctypedef int (*is_valid_func_t)(const Transition* self, const State* state) except -1
|
||||
GoldParse gold) except -1
|
||||
|
||||
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
||||
|
||||
|
@ -39,4 +33,4 @@ cdef class TransitionSystem:
|
|||
cdef const Transition best_valid(self, const weight_t*, const State*) except *
|
||||
|
||||
cdef const Transition best_gold(self, const weight_t*, const State*,
|
||||
const TokenC*) except *
|
||||
GoldParse gold) except *
|
||||
|
|
|
@ -25,22 +25,7 @@ cdef class TransitionSystem:
|
|||
raise NotImplementedError
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if scores[i] > score and self.c[i].is_valid(&self.c[i], s):
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||
# actions
|
||||
#if best.move == SHIFT:
|
||||
# score = MIN_SCORE
|
||||
# for i in range(self.n_moves):
|
||||
# if self.c[i].move == RIGHT and scores[i] > score:
|
||||
# best.label = self.c[i].label
|
||||
# score = scores[i]
|
||||
return best
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||
const TokenC* gold) except *:
|
||||
|
|
Loading…
Reference in New Issue
Block a user