mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Refactor transition system oracles, to split out move and label cost. Preparing to add Unshift move. Will exclude non-monotonic.
This commit is contained in:
parent
e2578fbb90
commit
8f142c1838
|
@ -11,6 +11,7 @@ from ._state cimport count_left_kids
|
|||
from ..structs cimport TokenC
|
||||
|
||||
from .transition_system cimport do_func_t, get_cost_func_t
|
||||
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
@ -46,6 +47,35 @@ MOVE_NAMES[CONSTITUENT] = 'C'
|
|||
MOVE_NAMES[ADJUST] = 'A'
|
||||
|
||||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
# When we push a word, we can't make arcs to or from the stack. So, we lose
|
||||
# any of those arcs.
|
||||
cdef int cost = 0
|
||||
cost += head_in_stack(st, target, gold.heads)
|
||||
cost += children_in_stack(st, target, gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(st, target, gold.heads)
|
||||
cost += head_in_buffer(st, target, gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
|
||||
if gold.heads[child] != head:
|
||||
return 0
|
||||
elif gold.labels[child] == -1:
|
||||
return 0
|
||||
elif gold.labels[child] == label:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
|
@ -62,14 +92,20 @@ cdef class Shift:
|
|||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Shift.is_valid(s, label):
|
||||
return 9000
|
||||
cost = 0
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
cdef int cost = push_cost(s, gold, s.i)
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if Break.is_valid(s, label) and Break.cost(s, gold, -1) == 0:
|
||||
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
@staticmethod
|
||||
|
@ -89,11 +125,19 @@ cdef class Reduce:
|
|||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Reduce.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if NON_MONOTONIC:
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
return cost
|
||||
return pop_cost(s, gold, s.stack[0])
|
||||
else:
|
||||
return children_in_buffer(s, s.stack[0], gold.heads)
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
|
@ -117,19 +161,21 @@ cdef class LeftArc:
|
|||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not LeftArc.is_valid(s, label):
|
||||
return 9000
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not LeftArc.is_valid(s, -1):
|
||||
return 9000
|
||||
cdef int cost = 0
|
||||
if gold.heads[s.stack[0]] == s.i:
|
||||
cost += label != -1 and label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
||||
elif at_eol(s):
|
||||
# Are we root?
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
# If we're at EOL, prefer to reduce or break over left-arc
|
||||
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||
# Are we labelling correctly?
|
||||
cost += label != -1 and label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
|
@ -139,6 +185,14 @@ cdef class LeftArc:
|
|||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if label == -1 or gold.labels[s.stack[0]] == -1:
|
||||
return 0
|
||||
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
|
@ -154,17 +208,25 @@ cdef class RightArc:
|
|||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not RightArc.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost
|
||||
cost = 0
|
||||
if gold.heads[s.i] == s.stack[0]:
|
||||
cost += label != -1 and label != gold.labels[s.i]
|
||||
return cost
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_cost(gold, s.stack[0], s.i, label)
|
||||
#cdef int cost = 0
|
||||
#if gold.heads[s.i] == s.stack[0]:
|
||||
# cost += label != -1 and label != gold.labels[s.i]
|
||||
# return cost
|
||||
# This indicates missing head
|
||||
if gold.labels[s.i] != -1:
|
||||
cost += head_in_buffer(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
return cost
|
||||
#if gold.labels[s.i] != -1:
|
||||
# cost += head_in_buffer(s, s.i, gold.heads)
|
||||
#cost += children_in_stack(s, s.i, gold.heads)
|
||||
#cost += head_in_stack(s, s.i, gold.heads)
|
||||
#return cost
|
||||
|
||||
|
||||
cdef class Break:
|
||||
|
@ -207,6 +269,11 @@ cdef class Break:
|
|||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Break.is_valid(s, label):
|
||||
return 9000
|
||||
else:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
|
@ -214,6 +281,10 @@ cdef class Break:
|
|||
cost += children_in_stack(s, i, gold.heads)
|
||||
cost += head_in_stack(s, i, gold.heads)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Constituent:
|
||||
|
@ -280,6 +351,17 @@ cdef class Constituent:
|
|||
# loss = 1 # If we see the start position, set loss to 1
|
||||
#return loss
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not Constituent.is_valid(s, -1):
|
||||
return 9000
|
||||
raise Exception("Constituent move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
cdef class Adjust:
|
||||
@staticmethod
|
||||
|
@ -318,6 +400,16 @@ cdef class Adjust:
|
|||
if not Adjust.is_valid(s, label):
|
||||
return 9000
|
||||
raise Exception("Adjust move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not Adjust.is_valid(s, -1):
|
||||
return 9000
|
||||
raise Exception("Adjust move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
# The gold standard is indexed by end, then by start, then a set of labels
|
||||
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
|
||||
# Case 1: There are 0 brackets ending at this word.
|
||||
|
@ -460,32 +552,36 @@ cdef class ArcEager(TransitionSystem):
|
|||
output[i] = is_valid[self.c[i].move]
|
||||
|
||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef Transition move
|
||||
move.label = -1
|
||||
cdef int i, move, label
|
||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
cdef int[N_MOVES] move_costs
|
||||
move_costs[SHIFT] = Shift.cost(s, &gold.c, -1)
|
||||
move_costs[REDUCE] = Reduce.cost(s, &gold.c, -1)
|
||||
move_costs[LEFT] = LeftArc.cost(s, &gold.c, -1)
|
||||
move_costs[RIGHT] = RightArc.cost(s, &gold.c, -1)
|
||||
move_costs[BREAK] = Break.cost(s, &gold.c, -1)
|
||||
move_costs[CONSTITUENT] = Constituent.cost(s, &gold.c, -1)
|
||||
move_costs[ADJUST] = Adjust.cost(s, &gold.c, -1)
|
||||
for i in range(N_MOVES):
|
||||
move_costs[i] = -1
|
||||
move_cost_funcs[SHIFT] = Shift.move_cost
|
||||
move_cost_funcs[REDUCE] = Reduce.move_cost
|
||||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||
move_cost_funcs[BREAK] = Break.move_cost
|
||||
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
|
||||
move_cost_funcs[ADJUST] = Adjust.move_cost
|
||||
|
||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||
label_cost_funcs[BREAK] = Break.label_cost
|
||||
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
|
||||
label_cost_funcs[ADJUST] = Adjust.label_cost
|
||||
|
||||
cdef int i, label
|
||||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
for i in range(self.n_moves):
|
||||
move = self.c[i]
|
||||
output[i] = move_costs[move.move]
|
||||
if output[i] == 0:
|
||||
label = -1
|
||||
if move.move == RIGHT and heads[s.i] == s.stack[0]:
|
||||
label = labels[s.i]
|
||||
if move.move == LEFT and heads[s.stack[0]] == s.i:
|
||||
label = labels[s.stack[0]]
|
||||
elif move.move == LEFT and at_eol(s) and (Reduce.is_valid(s, -1) or Break.is_valid(s, 1)):
|
||||
label = labels[s.stack[0]]
|
||||
output[i] += move.label != label and label != -1
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](s, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
|
|
|
@ -21,6 +21,8 @@ cdef struct Transition:
|
|||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
|
||||
ctypedef int (*do_func_t)(State* state, int label) except -1
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user