mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 17:52:31 +03:00
* Refactor the transition system oracles to split out move cost and label cost. Preparing to add an Unshift move. Will exclude non-monotonic.
This commit is contained in:
parent
e2578fbb90
commit
8f142c1838
|
@ -11,6 +11,7 @@ from ._state cimport count_left_kids
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
|
|
||||||
from .transition_system cimport do_func_t, get_cost_func_t
|
from .transition_system cimport do_func_t, get_cost_func_t
|
||||||
|
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
from ..gold cimport GoldParseC
|
from ..gold cimport GoldParseC
|
||||||
|
|
||||||
|
@ -46,6 +47,35 @@ MOVE_NAMES[CONSTITUENT] = 'C'
|
||||||
MOVE_NAMES[ADJUST] = 'A'
|
MOVE_NAMES[ADJUST] = 'A'
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions for the arc-eager oracle
|
||||||
|
|
||||||
|
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
    # Cost of pushing `target`: once a word is pushed, no arc can be made
    # between it and anything on the stack, so every gold arc of that kind
    # is lost. The result is the sum of head_in_stack and children_in_stack
    # for `target`, evaluated against gold.heads.
    cdef int lost_head = head_in_stack(st, target, gold.heads)
    cdef int lost_children = children_in_stack(st, target, gold.heads)
    return lost_head + lost_children
|
||||||
|
|
||||||
|
|
||||||
|
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
    # Cost of popping `target`: the sum of children_in_buffer and
    # head_in_buffer for `target`, evaluated against gold.heads.
    # Presumably these count gold arcs between `target` and the buffer that
    # can no longer be built after the pop -- confirm against the helpers.
    cdef int lost_children = children_in_buffer(st, target, gold.heads)
    cdef int lost_head = head_in_buffer(st, target, gold.heads)
    return lost_children + lost_head
|
||||||
|
|
||||||
|
|
||||||
|
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
    # Label cost of an arc head -> child carrying `label`.
    # Returns 1 only when the gold head of `child` is `head`, the gold label
    # is not -1, and the gold label differs from `label`; otherwise 0.
    cdef int gold_label
    # The arc itself does not match the gold head, so the label is not
    # charged here.
    if gold.heads[child] != head:
        return 0
    gold_label = gold.labels[child]
    # A gold label of -1 is never penalised (it appears to mark a missing
    # annotation -- confirm), and a matching label is free.
    if gold_label == -1 or gold_label == label:
        return 0
    return 1
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
|
@ -62,14 +92,20 @@ cdef class Shift:
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if not Shift.is_valid(s, label):
|
if not Shift.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
|
||||||
cost += head_in_stack(s, s.i, gold.heads)
|
|
||||||
cost += children_in_stack(s, s.i, gold.heads)
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
|
cdef int cost = push_cost(s, gold, s.i)
|
||||||
# If we can break, and there's no cost to doing so, we should
|
# If we can break, and there's no cost to doing so, we should
|
||||||
if Break.is_valid(s, label) and Break.cost(s, gold, -1) == 0:
|
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
|
||||||
cost += 1
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef class Reduce:
|
cdef class Reduce:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -89,11 +125,19 @@ cdef class Reduce:
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if not Reduce.is_valid(s, label):
|
if not Reduce.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
cdef int cost = 0
|
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
return pop_cost(s, gold, s.stack[0])
|
||||||
return cost
|
else:
|
||||||
|
return children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class LeftArc:
|
cdef class LeftArc:
|
||||||
|
@ -117,19 +161,21 @@ cdef class LeftArc:
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if not LeftArc.is_valid(s, label):
|
if not LeftArc.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
|
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
|
if not LeftArc.is_valid(s, -1):
|
||||||
|
return 9000
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
if gold.heads[s.stack[0]] == s.i:
|
if gold.heads[s.stack[0]] == s.i:
|
||||||
cost += label != -1 and label != gold.labels[s.stack[0]]
|
|
||||||
return cost
|
return cost
|
||||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
|
||||||
elif at_eol(s):
|
elif at_eol(s):
|
||||||
# Are we root?
|
# Are we root?
|
||||||
if gold.labels[s.stack[0]] != -1:
|
if gold.labels[s.stack[0]] != -1:
|
||||||
# If we're at EOL, prefer to reduce or break over left-arc
|
# If we're at EOL, prefer to reduce or break over left-arc
|
||||||
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
||||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||||
# Are we labelling correctly?
|
|
||||||
cost += label != -1 and label != gold.labels[s.stack[0]]
|
|
||||||
return cost
|
return cost
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
|
@ -139,6 +185,14 @@ cdef class LeftArc:
|
||||||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if label == -1 or gold.labels[s.stack[0]] == -1:
|
||||||
|
return 0
|
||||||
|
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef class RightArc:
|
cdef class RightArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -154,17 +208,25 @@ cdef class RightArc:
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if not RightArc.is_valid(s, label):
|
if not RightArc.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
cdef int cost
|
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||||
cost = 0
|
|
||||||
if gold.heads[s.i] == s.stack[0]:
|
@staticmethod
|
||||||
cost += label != -1 and label != gold.labels[s.i]
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
return cost
|
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return arc_cost(gold, s.stack[0], s.i, label)
|
||||||
|
#cdef int cost = 0
|
||||||
|
#if gold.heads[s.i] == s.stack[0]:
|
||||||
|
# cost += label != -1 and label != gold.labels[s.i]
|
||||||
|
# return cost
|
||||||
# This indicates missing head
|
# This indicates missing head
|
||||||
if gold.labels[s.i] != -1:
|
#if gold.labels[s.i] != -1:
|
||||||
cost += head_in_buffer(s, s.i, gold.heads)
|
# cost += head_in_buffer(s, s.i, gold.heads)
|
||||||
cost += children_in_stack(s, s.i, gold.heads)
|
#cost += children_in_stack(s, s.i, gold.heads)
|
||||||
cost += head_in_stack(s, s.i, gold.heads)
|
#cost += head_in_stack(s, s.i, gold.heads)
|
||||||
return cost
|
#return cost
|
||||||
|
|
||||||
|
|
||||||
cdef class Break:
|
cdef class Break:
|
||||||
|
@ -207,6 +269,11 @@ cdef class Break:
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
if not Break.is_valid(s, label):
|
if not Break.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
|
else:
|
||||||
|
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
# When we break, we Reduce all of the words on the stack.
|
# When we break, we Reduce all of the words on the stack.
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
|
@ -215,6 +282,10 @@ cdef class Break:
|
||||||
cost += head_in_stack(s, i, gold.heads)
|
cost += head_in_stack(s, i, gold.heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef class Constituent:
|
cdef class Constituent:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -280,6 +351,17 @@ cdef class Constituent:
|
||||||
# loss = 1 # If we see the start position, set loss to 1
|
# loss = 1 # If we see the start position, set loss to 1
|
||||||
#return loss
|
#return loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
|
if not Constituent.is_valid(s, -1):
|
||||||
|
return 9000
|
||||||
|
raise Exception("Constituent move should be disabled currently")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Adjust:
|
cdef class Adjust:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -318,6 +400,16 @@ cdef class Adjust:
|
||||||
if not Adjust.is_valid(s, label):
|
if not Adjust.is_valid(s, label):
|
||||||
return 9000
|
return 9000
|
||||||
raise Exception("Adjust move should be disabled currently")
|
raise Exception("Adjust move should be disabled currently")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||||
|
if not Adjust.is_valid(s, -1):
|
||||||
|
return 9000
|
||||||
|
raise Exception("Adjust move should be disabled currently")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
return 0
|
||||||
# The gold standard is indexed by end, then by start, then a set of labels
|
# The gold standard is indexed by end, then by start, then a set of labels
|
||||||
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
|
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
|
||||||
# Case 1: There are 0 brackets ending at this word.
|
# Case 1: There are 0 brackets ending at this word.
|
||||||
|
@ -460,32 +552,36 @@ cdef class ArcEager(TransitionSystem):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||||
cdef Transition move
|
cdef int i, move, label
|
||||||
move.label = -1
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
|
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||||
cdef int[N_MOVES] move_costs
|
cdef int[N_MOVES] move_costs
|
||||||
move_costs[SHIFT] = Shift.cost(s, &gold.c, -1)
|
for i in range(N_MOVES):
|
||||||
move_costs[REDUCE] = Reduce.cost(s, &gold.c, -1)
|
move_costs[i] = -1
|
||||||
move_costs[LEFT] = LeftArc.cost(s, &gold.c, -1)
|
move_cost_funcs[SHIFT] = Shift.move_cost
|
||||||
move_costs[RIGHT] = RightArc.cost(s, &gold.c, -1)
|
move_cost_funcs[REDUCE] = Reduce.move_cost
|
||||||
move_costs[BREAK] = Break.cost(s, &gold.c, -1)
|
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||||
move_costs[CONSTITUENT] = Constituent.cost(s, &gold.c, -1)
|
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||||
move_costs[ADJUST] = Adjust.cost(s, &gold.c, -1)
|
move_cost_funcs[BREAK] = Break.move_cost
|
||||||
|
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
|
||||||
|
move_cost_funcs[ADJUST] = Adjust.move_cost
|
||||||
|
|
||||||
|
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||||
|
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||||
|
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||||
|
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||||
|
label_cost_funcs[BREAK] = Break.label_cost
|
||||||
|
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
|
||||||
|
label_cost_funcs[ADJUST] = Adjust.label_cost
|
||||||
|
|
||||||
cdef int i, label
|
|
||||||
cdef int* labels = gold.c.labels
|
cdef int* labels = gold.c.labels
|
||||||
cdef int* heads = gold.c.heads
|
cdef int* heads = gold.c.heads
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
move = self.c[i]
|
move = self.c[i].move
|
||||||
output[i] = move_costs[move.move]
|
label = self.c[i].label
|
||||||
if output[i] == 0:
|
if move_costs[move] == -1:
|
||||||
label = -1
|
move_costs[move] = move_cost_funcs[move](s, &gold.c)
|
||||||
if move.move == RIGHT and heads[s.i] == s.stack[0]:
|
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||||
label = labels[s.i]
|
|
||||||
if move.move == LEFT and heads[s.stack[0]] == s.i:
|
|
||||||
label = labels[s.stack[0]]
|
|
||||||
elif move.move == LEFT and at_eol(s) and (Reduce.is_valid(s, -1) or Break.is_valid(s, 1)):
|
|
||||||
label = labels[s.stack[0]]
|
|
||||||
output[i] += move.label != label and label != -1
|
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
|
|
|
@ -21,6 +21,8 @@ cdef struct Transition:
|
||||||
|
|
||||||
|
|
||||||
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||||
|
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
|
||||||
|
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(State* state, int label) except -1
|
ctypedef int (*do_func_t)(State* state, int label) except -1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user