mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
* Ensure high loss for invalid moves, and fix label reading for arc-eager
This commit is contained in:
parent
f5f15a1ef2
commit
fdabd93bfb
|
@ -35,14 +35,15 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
|||
cdef class ArcEager(TransitionSystem):
|
||||
@classmethod
|
||||
def get_labels(cls, gold_parses):
|
||||
labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True},
|
||||
LEFT: {0: True}, BREAK: {0: True}}
|
||||
labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
|
||||
LEFT: {}, BREAK: {'ROOT': True}}
|
||||
for parse in gold_parses:
|
||||
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
|
||||
if head > i:
|
||||
labels[RIGHT][label] = True
|
||||
else:
|
||||
labels[LEFT][label] = True
|
||||
if label != 'ROOT':
|
||||
if head > i:
|
||||
labels[RIGHT][label] = True
|
||||
elif head < i:
|
||||
labels[LEFT][label] = True
|
||||
return labels
|
||||
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||
|
@ -71,6 +72,8 @@ cdef class ArcEager(TransitionSystem):
|
|||
if scores[i] > score and is_valid[self.c[i].move]:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert best.clas < self.n_moves
|
||||
assert score > MIN_SCORE
|
||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||
# actions
|
||||
if best.move == SHIFT:
|
||||
|
@ -85,7 +88,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int _do_shift(const Transition* self, State* state) except -1:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
get_s0(state).dep = self.label
|
||||
state.sent[state.i].dep = self.label
|
||||
push_stack(state)
|
||||
|
||||
|
||||
|
@ -124,7 +127,8 @@ do_funcs[BREAK] = _do_break
|
|||
|
||||
|
||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert not at_eol(s)
|
||||
if not _can_shift(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||
|
@ -137,7 +141,8 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
|
|||
|
||||
|
||||
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
if not _can_right(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
if gold.c_heads[s.i] == s.stack[0]:
|
||||
cost += self.label != gold.c_labels[s.i]
|
||||
|
@ -151,7 +156,8 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
|
|||
|
||||
|
||||
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
if not _can_left(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
if gold.c_heads[s.stack[0]] == s.i:
|
||||
cost += self.label != gold.c_labels[s.stack[0]]
|
||||
|
@ -166,6 +172,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
|
|||
|
||||
|
||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
if not _can_reduce(s):
|
||||
return 9000
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||
if NON_MONOTONIC:
|
||||
|
@ -174,6 +182,8 @@ cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) ex
|
|||
|
||||
|
||||
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
if not _can_break(s):
|
||||
return 9000
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
|
|
Loading…
Reference in New Issue
Block a user