mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Cost functions now take a StateClass argument instead of a State*.
This commit is contained in:
parent
e0cf61f591
commit
4b98b3e9c8
|
@ -53,9 +53,7 @@ MOVE_NAMES[BREAK] = 'B'
|
|||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
cdef StateClass stcls = StateClass(st.sent_len)
|
||||
stcls.from_struct(st)
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
||||
cdef int cost = 0
|
||||
cdef int i, S_i
|
||||
for i in range(stcls.stack_depth()):
|
||||
|
@ -64,13 +62,11 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
|
|||
cost += 1
|
||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||
cost += 1
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(st, gold) == 0
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
cdef StateClass stcls = StateClass(st.sent_len)
|
||||
stcls.from_struct(st)
|
||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) except -1:
|
||||
cdef int cost = 0
|
||||
cdef int i, B_i
|
||||
for i in range(stcls.buffer_length()):
|
||||
|
@ -81,9 +77,7 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
|
|||
break
|
||||
return cost
|
||||
|
||||
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
|
||||
cdef StateClass stcls = StateClass(st.sent_len)
|
||||
stcls.from_struct(st)
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) except -1:
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif stcls.H(child) == gold.heads[child]:
|
||||
|
@ -133,15 +127,15 @@ cdef class Shift:
|
|||
push_stack(state)
|
||||
|
||||
# Shift.cost: total cost of the Shift move for `label`, computed as
# Shift.move_cost plus Shift.label_cost.
# Two revisions are listed back to back: the first signature/body pair takes
# a `const State*`, the second takes a `StateClass`.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||
|
||||
# Shift.move_cost: delegates to push_cost. The target word passed to
# push_cost is `s.i` in the `const State*` form and `s.B(0)` in the
# `StateClass` form.
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
return push_cost(s, gold, s.i)
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
||||
# Shift.label_cost: always returns 0; neither the state, the gold parse nor
# `label` is consulted. Listed with both the `const State*` and the
# `StateClass` signature; the single body line is shared.
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
@ -160,18 +154,18 @@ cdef class Reduce:
|
|||
pop_stack(state)
|
||||
|
||||
# Reduce.cost: total cost of the Reduce move for `label`, computed as
# Reduce.move_cost plus Reduce.label_cost. Listed with both the
# `const State*` and the `StateClass` signature; the body line is shared.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
# Reduce.move_cost: when NON_MONOTONIC is set, returns pop_cost for the top
# stack word; otherwise returns children_in_buffer for that word against
# gold.heads. The top stack word is read as `s.stack[0]` in the
# `const State*` form and as `s.S(0)` in the `StateClass` form.
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
if NON_MONOTONIC:
|
||||
return pop_cost(s, gold, s.stack[0])
|
||||
return pop_cost(s, gold, s.S(0))
|
||||
else:
|
||||
return children_in_buffer(s, s.stack[0], gold.heads)
|
||||
# NOTE(review): children_in_buffer is not shown in this listing; confirm it
# accepts a StateClass as its first argument.
return children_in_buffer(s, s.S(0), gold.heads)
|
||||
|
||||
# Reduce.label_cost: always returns 0; `label` is not consulted. Listed with
# both the `const State*` and the `StateClass` signature.
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
@ -193,19 +187,19 @@ cdef class LeftArc:
|
|||
pop_stack(state)
|
||||
|
||||
# LeftArc.cost: total cost of the LeftArc move for `label`, computed as
# LeftArc.move_cost plus LeftArc.label_cost. Listed with both the
# `const State*` and the `StateClass` signature; the body line is shared.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
# LeftArc.move_cost: returns 0 when the arc with the first buffer word as
# head and the top stack word as child is gold (arc_is_gold); otherwise
# returns pop_cost for the top stack word plus arc_cost for that same arc.
# The `const State*` form reads the two words as `s.i` / `s.stack[0]`; the
# `StateClass` form reads them as `s.B(0)` / `s.S(0)`.
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if arc_is_gold(gold, s.i, s.stack[0]):
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||
return 0
|
||||
else:
|
||||
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
|
||||
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||
|
||||
# LeftArc.label_cost: truthy only when the arc (head = first buffer word,
# child = top stack word) is gold but `label` is not the gold label for it
# (arc_is_gold and not label_is_gold). The `const State*` form uses
# `s.i` / `s.stack[0]`; the `StateClass` form uses `s.B(0)` / `s.S(0)`.
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_is_gold(gold, s.i, s.stack[0]) and not label_is_gold(gold, s.i, s.stack[0], label)
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
|
@ -219,19 +213,19 @@ cdef class RightArc:
|
|||
push_stack(state)
|
||||
|
||||
# RightArc.cost: total cost of the RightArc move for `label`, computed as
# RightArc.move_cost plus RightArc.label_cost. Listed with both the
# `const State*` and the `StateClass` signature; the body line is shared.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
# RightArc.move_cost: returns 0 when the arc with the top stack word as head
# and the first buffer word as child is gold (arc_is_gold); otherwise returns
# push_cost for the first buffer word plus arc_cost for that same arc.
# The `const State*` form reads the two words as `s.stack[0]` / `s.i`; the
# `StateClass` form reads them as `s.S(0)` / `s.B(0)`.
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if arc_is_gold(gold, s.stack[0], s.i):
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
else:
|
||||
return push_cost(s, gold, s.i) + arc_cost(s, gold, s.stack[0], s.i)
|
||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||
|
||||
# RightArc.label_cost: truthy only when the arc (head = top stack word,
# child = first buffer word) is gold but `label` is not the gold label for it
# (arc_is_gold and not label_is_gold). The `const State*` form uses
# `s.stack[0]` / `s.i`; the `StateClass` form uses `s.S(0)` / `s.B(0)`.
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_is_gold(gold, s.stack[0], s.i) and not label_is_gold(gold, s.stack[0], s.i, label)
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
|
@ -271,21 +265,25 @@ cdef class Break:
|
|||
state.stack_len -= 1
|
||||
|
||||
# Break.cost: total cost of the Break move for `label`, computed as
# Break.move_cost plus Break.label_cost. Listed with both the
# `const State*` and the `StateClass` signature; the body line is shared.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
# Break.move_cost: counts gold dependencies that link a word on the stack
# with a word in the buffer, in either direction, and returns that count.
# Two revisions of the counting loop are listed back to back:
#   - `const State*` form: for each index from s.i to s.sent_len, adds
#     children_in_stack and head_in_stack.
#   - `StateClass` form: for every buffer word B_i and every stack word S_i,
#     adds 1 when gold.heads[B_i] == S_i and 1 when gold.heads[S_i] == B_i.
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
cdef int move_cost(StateClass s, const GoldParseC* gold) except -1:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
for i in range(s.i, s.sent_len):
|
||||
cost += children_in_stack(s, i, gold.heads)
|
||||
cost += head_in_stack(s, i, gold.heads)
|
||||
# NOTE(review): `j`, used by the inner loop below, is not part of this
# declaration — confirm whether it should be declared as a C int as well.
cdef int i, B_i, S_i
|
||||
for i in range(s.buffer_length()):
|
||||
B_i = s.B(i)
|
||||
for j in range(s.stack_depth()):
|
||||
# S_i holds the j-th stack word (the name says `i`, the index is `j`).
S_i = s.S(j)
|
||||
# Each comparison contributes 1 to the cost when it holds.
cost += gold.heads[B_i] == S_i
|
||||
cost += gold.heads[S_i] == B_i
|
||||
return cost
|
||||
|
||||
# Break.label_cost: always returns 0; `label` is not consulted. Listed with
# both the `const State*` and the `StateClass` signature.
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
@ -389,7 +387,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
|
||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i, move, label
|
||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
|
@ -411,8 +409,6 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
|
||||
cdef StateClass stcls = StateClass(s.sent_len)
|
||||
stcls.from_struct(s)
|
||||
self.set_valid(self._is_valid, stcls)
|
||||
for i in range(self.n_moves):
|
||||
if not self._is_valid[i]:
|
||||
|
@ -421,8 +417,8 @@ cdef class ArcEager(TransitionSystem):
|
|||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](s, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
|
|
|
@ -36,18 +36,14 @@ MOVE_NAMES[OUT] = 'O'
|
|||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
||||
|
||||
# entity_is_open: true when the state has at least one entity
# (s.ents_len >= 1) and the current entity's `end` field is 0.
# NOTE(review): presumably end == 0 marks an entity that has not been closed
# yet — confirm against the State struct definition.
cdef bint entity_is_open(const State *s) except -1:
|
||||
return s.ents_len >= 1 and s.ent.end == 0
|
||||
|
||||
|
||||
# _entity_is_sunk: returns False when no entity is open. Otherwise it looks
# up one gold transition in `golds` and returns True when that gold move is
# neither BEGIN nor UNIT, or when its label differs from the open entity's
# label; in every other case it returns False.
# Two revisions are listed back to back: the `const State*` form uses
# entity_is_open(s), s.ent.start and s.ent.label; the `StateClass` form uses
# st.entity_is_open(), st.E(0) and st.E_(0).ent_type.
# NOTE(review): presumably st.E(0) is the start index of the open entity,
# matching s.ent.start — confirm against StateClass.
cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1:
|
||||
if not entity_is_open(s):
|
||||
cdef bint _entity_is_sunk(StateClass st, Transition* golds) except -1:
|
||||
if not st.entity_is_open():
|
||||
return False
|
||||
|
||||
cdef const Transition* gold = &golds[s.ent.start]
|
||||
cdef const Transition* gold = &golds[st.E(0)]
|
||||
if gold.move != BEGIN and gold.move != UNIT:
|
||||
return True
|
||||
elif gold.label != s.ent.label:
|
||||
elif gold.label != st.E_(0).ent_type:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -166,7 +162,7 @@ cdef class Missing:
|
|||
raise NotImplementedError
|
||||
|
||||
# Missing.cost: always returns the constant 9000, regardless of state, gold
# parse or label. Listed with both the `const State*` and the `StateClass`
# signature.
# NOTE(review): 9000 presumably acts as a prohibitively high cost — confirm.
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
return 9000
|
||||
|
||||
|
||||
|
@ -187,7 +183,7 @@ cdef class Begin:
|
|||
s.i += 1
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
|
@ -216,7 +212,7 @@ cdef class In:
|
|||
s.i += 1
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
move = IN
|
||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
|
@ -257,7 +253,7 @@ cdef class Last:
|
|||
s.i += 1
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
move = LAST
|
||||
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
|
@ -301,7 +297,7 @@ cdef class Unit:
|
|||
s.i += 1
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
|
@ -329,7 +325,7 @@ cdef class Out:
|
|||
s.i += 1
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
|
|
|
@ -144,8 +144,8 @@ cdef class Parser:
|
|||
_new_fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
best = self.moves.best_gold(scores, state, gold)
|
||||
cost = guess.get_cost(state, &gold.c, guess.label)
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(state, guess.label)
|
||||
loss += cost
|
||||
|
@ -191,7 +191,7 @@ cdef class Parser:
|
|||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
self.moves.set_costs(beam.costs[i], state, gold)
|
||||
self.moves.set_costs(beam.costs[i], stcls, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
|
|
|
@ -18,13 +18,13 @@ cdef struct Transition:
|
|||
weight_t score
|
||||
|
||||
bint (*is_valid)(StateClass state, int label) except -1
|
||||
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
int (*do)(State* state, int label) except -1
|
||||
|
||||
|
||||
# Function-pointer types for transition callbacks. Each returns an int and is
# declared `except -1`.
# The three cost typedefs are listed twice: first with a `const State*` state
# parameter, then with a `StateClass` state parameter.
#   get_cost_func_t   -- (state, gold, label)
#   move_cost_func_t  -- (state, gold)
#   label_cost_func_t -- (state, gold, label)
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
|
||||
# do_func_t appears only once here and takes a mutable `State*`.
ctypedef int (*do_func_t)(State* state, int label) except -1
|
||||
|
||||
|
@ -47,9 +47,9 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef int set_valid(self, bint* output, StateClass state) except -1
|
||||
|
||||
# Declaration of set_costs (no body here). Listed twice: with a
# `const State*` state parameter and with a `StateClass` state parameter.
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1
|
||||
cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
|
||||
|
||||
# Declaration of best_gold (no body here); returns a Transition and is
# declared `except *`. The first line is listed twice -- with a `State*` and
# with a `StateClass` state parameter -- and both share the continuation
# line that closes the parameter list.
cdef Transition best_gold(self, const weight_t* scores, State* state,
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass state,
|
||||
GoldParse gold) except *
|
||||
|
|
|
@ -50,21 +50,22 @@ cdef class TransitionSystem:
|
|||
cdef int set_valid(self, bint* output, StateClass state) except -1:
|
||||
raise NotImplementedError
|
||||
|
||||
# set_costs: fills output[i] with a cost for each of the self.n_moves
# transitions in self.c.
# Two revisions are listed back to back:
#   - `const State*` form: calls get_cost for every transition.
#   - `StateClass` form: calls get_cost only for transitions whose is_valid
#     check passes for their label; all others get the constant 9000.
# NOTE(review): 9000 presumably acts as a prohibitively high cost for
# invalid moves — confirm.
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
output[i] = self.c[i].get_cost(s, &gold.c, self.c[i].label)
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
else:
|
||||
output[i] = 9000
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
|
||||
GoldParse gold) except *:
|
||||
cdef StateClass stcls = StateClass(s.sent_len)
|
||||
stcls.from_struct(s)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
|
||||
cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
if scores[i] > score and cost == 0:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
|
|
Loading…
Reference in New Issue
Block a user