* Move StateClass into the interface for is_valid

This commit is contained in:
Matthew Honnibal 2015-06-09 23:23:28 +02:00
parent 09617a4638
commit e0cf61f591
10 changed files with 132 additions and 340 deletions

View File

@ -1,10 +1,11 @@
from thinc.typedefs cimport atom_t
from ._state cimport State
from .stateclass cimport StateClass
cdef int fill_context(atom_t* context, State* state) except -1
cdef int _new_fill_context(atom_t* context, State* state) except -1
cdef int _new_fill_context(atom_t* context, StateClass state) except -1
# Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -65,13 +65,11 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[10] = token.ent_iob
context[11] = token.ent_type
cdef int _new_fill_context(atom_t* ctxt, State* state) except -1:
cdef int _new_fill_context(atom_t* ctxt, StateClass st) except -1:
# Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix...
cdef StateClass st = StateClass(state.sent_len)
st.from_struct(state)
fill_token(&ctxt[S2w], st.S_(2))
fill_token(&ctxt[S1w], st.S_(1))
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
@ -89,8 +87,8 @@ cdef int _new_fill_context(atom_t* ctxt, State* state) except -1:
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
# TODO
fill_token(&ctxt[E0w], get_e0(state))
fill_token(&ctxt[E1w], get_e1(state))
fill_token(&ctxt[E0w], st.E_(0))
fill_token(&ctxt[E1w], st.E_(1))
if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min(st.S(0) - st.B(0), 5) # TODO: This is backwards!!

View File

@ -4,6 +4,8 @@ from thinc.typedefs cimport weight_t
from ._state cimport State
from .stateclass cimport StateClass
from .transition_system cimport TransitionSystem, Transition
cdef class ArcEager(TransitionSystem):

View File

@ -40,9 +40,6 @@ cdef enum:
BREAK
CONSTITUENT
ADJUST
N_MOVES
@ -52,8 +49,6 @@ MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B'
MOVE_NAMES[CONSTITUENT] = 'C'
MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle
@ -69,15 +64,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
cost += Break.is_valid(stcls, -1) and Break.move_cost(st, gold) == 0
return cost
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
#cost += head_in_stack(st, target, gold.heads)
#cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
#cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
#return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
@ -92,10 +80,6 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
return cost
#cost += children_in_buffer(st, target, gold.heads)
#cost += head_in_buffer(st, target, gold.heads)
#return cost
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
@ -108,14 +92,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
return 1
else:
return 0
#if arc_is_gold(gold, head, child):
# return 0
#elif (child + st.sent[child].head) == gold.heads[child]:
# return 1
#elif gold.heads[child] >= st.i:
# return 1
#else:
# return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
@ -146,11 +122,7 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
cdef class Shift:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
return not st.eol()
@staticmethod
@ -162,8 +134,6 @@ cdef class Shift:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Shift.is_valid(s, label):
return 9000
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
@staticmethod
@ -177,19 +147,12 @@ cdef class Shift:
cdef class Reduce:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC:
return st.stack_depth() >= 2 #and not missing_brackets(s)
else:
return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
return s.stack_len >= 2 #and not missing_brackets(s)
else:
return s.stack_len >= 2 and has_head(get_s0(s))
@staticmethod
cdef int transition(State* state, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
@ -198,8 +161,6 @@ cdef class Reduce:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Reduce.is_valid(s, label):
return 9000
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
@ -216,19 +177,12 @@ cdef class Reduce:
cdef class LeftArc:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC:
return st.stack_depth() >= 1 #and not missing_brackets(s)
else:
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
return s.stack_len >= 1 #and not missing_brackets(s)
else:
return s.stack_len >= 1 and not has_head(get_s0(s))
@staticmethod
cdef int transition(State* state, int label) except -1:
# Interpret left-arcs from EOL as attachment to root
@ -240,15 +194,11 @@ cdef class LeftArc:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not LeftArc.is_valid(s, label):
return 9000
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not LeftArc.is_valid(s, -1):
return 9000
elif arc_is_gold(gold, s.i, s.stack[0]):
if arc_is_gold(gold, s.i, s.stack[0]):
return 0
else:
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
@ -260,11 +210,7 @@ cdef class LeftArc:
cdef class RightArc:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return s.stack_len >= 1 and not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
return st.stack_depth() >= 1 and not st.eol()
@staticmethod
@ -274,8 +220,6 @@ cdef class RightArc:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not RightArc.is_valid(s, label):
return 9000
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
@ -292,7 +236,7 @@ cdef class RightArc:
cdef class Break:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
cdef int i
if not USE_BREAK:
return False
@ -317,32 +261,6 @@ cdef class Break:
# TODO: Constituency constraints
return True
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
cdef int i
if not USE_BREAK:
return False
elif at_eol(s):
return False
elif s.stack_len < 1:
return False
elif NON_MONOTONIC:
return True
else:
# In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
# we prefer to relax this constraint. This is helpful in parsing whole
# documents, because then we don't get stuck with words on the stack.
seen_headless = False
for i in range(s.stack_len):
if s.sent[s.stack[-i]].head == 0:
if seen_headless:
return False
else:
seen_headless = True
# TODO: Constituency constraints
return True
@staticmethod
cdef int transition(State* state, int label) except -1:
state.sent[state.i-1].sent_end = True
@ -354,10 +272,7 @@ cdef class Break:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Break.is_valid(s, label):
return 9000
else:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
@ -374,163 +289,11 @@ cdef class Break:
return 0
cdef class Constituent:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if s.stack_len < 1:
return False
return False
#else:
# # If all stack elements are popped, can't constituent
# for i in range(s.ctnts.stack_len):
# if not s.ctnts.is_popped[-i]:
# return True
# else:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* bracket = new_bracket(state.ctnts)
#bracket.parent = NULL
#bracket.label = self.label
#bracket.head = get_s0(state)
#bracket.length = 0
#attach(bracket, state.ctnts.stack)
# Attach rightward children. They're in the brackets array somewhere
# between here and B0.
#cdef Constituent* node
#cdef const TokenC* node_gov
#for i in range(1, bracket - state.ctnts.stack):
# node = bracket - i
# node_gov = node.head + node.head.head
# if node_gov == bracket.head:
# attach(bracket, node)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Constituent.is_valid(s, label):
return 9000
raise Exception("Constituent move should be disabled currently")
# The gold standard is indexed by end, then by start, then a set of labels
#brackets = gold.brackets(get_s0(s).r_edge, {})
#if not brackets:
# return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
# Index the current brackets in the state
#existing = set()
#for i in range(s.ctnt_len):
# if ctnt.end == s.r_edge and ctnt.label == self.label:
# existing.add(ctnt.start)
#cdef int loss = 2
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets and child.l_edge not in existing:
# if self.label in brackets[child.l_edge]
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Constituent.is_valid(s, -1):
return 9000
raise Exception("Constituent move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
cdef class Adjust:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return False
#if s.ctnts.stack_len < 2:
# return False
#cdef const Constituent* b1 = s.ctnts.stack[-1]
#cdef const Constituent* b0 = s.ctnts.stack[0]
#if (b1.head + b1.head.head) != b0.head:
# return False
#elif b0.head >= b1.head:
# return False
#elif b0 >= b1:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* b0 = state.ctnts.stack[0]
#cdef Constituent* b1 = state.ctnts.stack[1]
#assert (b1.head + b1.head.head) == b0.head
#assert b0.head < b1.head
#assert b0 < b1
#attach(b0, b1)
## Pop B1 from stack, but keep B0 on top
#state.ctnts.stack -= 1
#state.ctnts.stack[0] = b0
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Adjust.is_valid(s, label):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Adjust.is_valid(s, -1):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
# The gold standard is indexed by end, then by start, then a set of labels
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
# Case 1: There are 0 brackets ending at this word.
# --> Cost is sunk, but must allow brackets to begin
#if not gold_starts:
# return 0
# Is the top bracket correct?
#gold_labels = gold_starts.get(s.ctnt.start, set())
# TODO: Case where we have a unary rule
# TODO: Case where two brackets end on this word, with top bracket starting
# before
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
#cdef int i
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets:
# if self.label in brackets[child.l_edge]:
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
CONSTITUENT: {}, ADJUST: {'': True}}
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
for raw_text, sents in gold_parses:
for (ids, words, tags, heads, labels, iob), ctnts in sents:
for child, head, label in zip(ids, heads, labels):
@ -539,8 +302,6 @@ cdef class ArcEager(TransitionSystem):
move_labels[RIGHT][label] = True
elif head > child:
move_labels[LEFT][label] = True
for start, end, label in ctnts:
move_labels[CONSTITUENT][label] = True
return move_labels
cdef int preprocess_gold(self, GoldParse gold) except -1:
@ -604,14 +365,6 @@ cdef class ArcEager(TransitionSystem):
t.is_valid = Break.is_valid
t.do = Break.transition
t.get_cost = Break.cost
elif move == CONSTITUENT:
t.is_valid = Constituent.is_valid
t.do = Constituent.transition
t.get_cost = Constituent.cost
elif move == ADJUST:
t.is_valid = Adjust.is_valid
t.do = Adjust.transition
t.get_cost = Adjust.cost
else:
raise Exception(move)
return t
@ -625,18 +378,13 @@ cdef class ArcEager(TransitionSystem):
if state.sent[i].head == 0 and state.sent[i].dep == 0:
state.sent[i].dep = root_label
cdef int set_valid(self, bint* output, const State* state) except -1:
raise Exception
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent.is_valid(state, -1)
is_valid[ADJUST] = False # Adjust.is_valid(state, -1)
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(stcls, -1)
cdef int i
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
@ -653,38 +401,36 @@ cdef class ArcEager(TransitionSystem):
move_cost_funcs[LEFT] = LeftArc.move_cost
move_cost_funcs[RIGHT] = RightArc.move_cost
move_cost_funcs[BREAK] = Break.move_cost
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
move_cost_funcs[ADJUST] = Adjust.move_cost
label_cost_funcs[SHIFT] = Shift.label_cost
label_cost_funcs[REDUCE] = Reduce.label_cost
label_cost_funcs[LEFT] = LeftArc.label_cost
label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
label_cost_funcs[ADJUST] = Adjust.label_cost
cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads
for i in range(self.n_moves):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](s, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
assert s is not NULL
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
self.set_valid(self._is_valid, stcls)
for i in range(self.n_moves):
if not self._is_valid[i]:
output[i] = 9000
else:
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](s, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(stcls, -1)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
@ -703,5 +449,3 @@ cdef class ArcEager(TransitionSystem):
best.label = self.c[i].label
score = scores[i]
return best

View File

@ -11,6 +11,8 @@ from thinc.typedefs cimport weight_t
from ..gold cimport GoldParseC
from ..gold cimport GoldParse
from .stateclass cimport StateClass
cdef enum:
MISSING
@ -132,14 +134,14 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move)
return t
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef int best = -1
cdef weight_t score = -90000
cdef const Transition* m
cdef int i
for i in range(self.n_moves):
m = &self.c[i]
if m.is_valid(s, m.label) and scores[i] > score:
if m.is_valid(stcls, m.label) and scores[i] > score:
best = i
score = scores[i]
assert best >= 0
@ -147,16 +149,16 @@ cdef class BiluoPushDown(TransitionSystem):
t.score = score
return t
cdef int set_valid(self, bint* output, const State* s) except -1:
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef int i
for i in range(self.n_moves):
m = &self.c[i]
output[i] = m.is_valid(s, m.label)
output[i] = m.is_valid(stcls, m.label)
cdef class Missing:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
cdef bint is_valid(StateClass st, int label) except -1:
return False
@staticmethod
@ -170,8 +172,8 @@ cdef class Missing:
cdef class Begin:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return label != 0 and not entity_is_open(s)
cdef bint is_valid(StateClass st, int label) except -1:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
@ -186,8 +188,6 @@ cdef class Begin:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Begin.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
@ -203,10 +203,11 @@ cdef class Begin:
# B, Gold U --> False (P)
return 1
cdef class In:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return entity_is_open(s) and label != 0 and s.ent.label == label
cdef bint is_valid(StateClass st, int label) except -1:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(State* s, int label) except -1:
@ -216,8 +217,6 @@ cdef class In:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not In.is_valid(s, label):
return 9000
move = IN
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
cdef int g_act = gold.ner[s.i].move
@ -245,11 +244,10 @@ cdef class In:
return 1
cdef class Last:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return entity_is_open(s) and label != 0 and s.ent.label == label
cdef bint is_valid(StateClass st, int label) except -1:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(State* s, int label) except -1:
@ -260,8 +258,6 @@ cdef class Last:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Last.is_valid(s, label):
return 9000
move = LAST
cdef int g_act = gold.ner[s.i].move
@ -290,8 +286,8 @@ cdef class Last:
cdef class Unit:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return label != 0 and not entity_is_open(s)
cdef bint is_valid(StateClass st, int label) except -1:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
@ -306,8 +302,6 @@ cdef class Unit:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Unit.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
@ -326,8 +320,8 @@ cdef class Unit:
cdef class Out:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return not entity_is_open(s)
cdef bint is_valid(StateClass st, int label) except -1:
return not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
@ -336,9 +330,6 @@ cdef class Out:
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Out.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label

View File

@ -40,8 +40,9 @@ from ..gold cimport GoldParse
from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport _new_fill_context as fill_context
#from ._parse_features cimport fill_context
from ._parse_features cimport _new_fill_context
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
DEBUG = False
@ -104,11 +105,13 @@ cdef class Parser:
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
cdef Transition guess
while not is_final(state):
fill_context(context, state)
stcls.from_struct(state)
_new_fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
guess = self.moves.best_valid(scores, stcls)
guess.do(state, guess.label)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
@ -133,12 +136,14 @@ cdef class Parser:
cdef const weight_t* scores
cdef Transition guess
cdef Transition best
cdef StateClass stcls = StateClass(state.sent_len)
cdef atom_t[CONTEXT_SIZE] context
loss = 0
while not is_final(state):
fill_context(context, state)
stcls.from_struct(state)
_new_fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(state, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost)
@ -174,12 +179,14 @@ cdef class Parser:
cdef int i, j, cost
cdef bint is_valid
cdef const Transition* move
cdef StateClass stcls = StateClass(gold.length)
for i in range(beam.size):
state = <State*>beam.at(i)
stcls.from_struct(state)
if not is_final(state):
fill_context(context, state)
self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], state)
self.moves.set_valid(beam.is_valid[i], stcls)
if gold is not None:
for i in range(beam.size):

View File

@ -2,7 +2,7 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
from ..structs cimport TokenC
from ..structs cimport TokenC, Entity
from ._state cimport State
@ -14,10 +14,12 @@ cdef class StateClass:
cdef int* _stack
cdef int* _buffer
cdef TokenC* _sent
cdef Entity* _ents
cdef TokenC _empty_token
cdef int length
cdef int _s_i
cdef int _b_i
cdef int _e_i
cdef int from_struct(self, const State* state) except -1
@ -25,6 +27,7 @@ cdef class StateClass:
cdef int B(self, int i) nogil
cdef int H(self, int i) nogil
cdef int E(self, int i) nogil
cdef int L(self, int i, int idx) nogil
cdef int R(self, int i, int idx) nogil
@ -33,6 +36,7 @@ cdef class StateClass:
cdef const TokenC* B_(self, int i) nogil
cdef const TokenC* H_(self, int i) nogil
cdef const TokenC* E_(self, int i) nogil
cdef const TokenC* L_(self, int i, int idx) nogil
cdef const TokenC* R_(self, int i, int idx) nogil
@ -40,12 +44,15 @@ cdef class StateClass:
cdef const TokenC* safe_get(self, int i) nogil
cdef bint empty(self) nogil
cdef bint entity_is_open(self) nogil
cdef bint eol(self) nogil
cdef bint is_final(self) nogil
cdef bint has_head(self, int i) nogil
cdef int n_L(self, int i) nogil
@ -64,6 +71,12 @@ cdef class StateClass:
cdef void add_arc(self, int head, int child, int label) nogil
cdef void del_arc(self, int head, int child) nogil
cdef void open_ent(self, int label) nogil
cdef void close_ent(self) nogil
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil
cdef void set_sent_end(self, int i) nogil

View File

@ -9,10 +9,12 @@ cdef class StateClass:
self._buffer = <int*>mem.alloc(length, sizeof(int))
self._stack = <int*>mem.alloc(length, sizeof(int))
self._sent = <TokenC*>mem.alloc(length, sizeof(TokenC))
self._ents = <Entity*>mem.alloc(length, sizeof(Entity))
self.mem = mem
self.length = length
self._s_i = 0
self._b_i = 0
self._e_i = 0
cdef int i
for i in range(length):
self._buffer[i] = i
@ -21,10 +23,13 @@ cdef class StateClass:
cdef int from_struct(self, const State* state) except -1:
    # Load this object's stack length, buffer position, entity count, token
    # array, stack contents and entity records from a State struct.
    #
    # state: source struct; its sent array must hold at least self.length
    #        tokens, because that many are copied.
    # Returns 0 implicitly; -1 is reserved for a propagated exception.
    self._s_i = state.stack_len
    self._b_i = state.i
    self._e_i = state.ents_len
    memcpy(self._sent, state.sent, sizeof(TokenC) * self.length)
    cdef int i
    # state.stack[0] lands in the highest used slot and state.stack[-1] in
    # the slot below it, and so on.
    for i in range(state.stack_len):
        self._stack[self._s_i - (i+1)] = state.stack[-i]
    # Entities use the same layout as the stack: entity_is_open() and
    # open_ent() both treat _ents[_e_i-1] as the most recent entity, so
    # state.ent[0] must land there rather than in slot 0.
    # NOTE(review): assumes state.ent points at the most recent entity, the
    # way state.stack is indexed above -- confirm against _state.
    for i in range(state.ents_len):
        self._ents[self._e_i - (i+1)] = state.ent[-i]
cdef int S(self, int i) nogil:
if i >= self._s_i:
@ -41,6 +46,9 @@ cdef class StateClass:
return -1
return self._sent[i].head + i
cdef int E(self, int i) nogil:
    # Token index at which the i-th most recent entity starts; E(0) is the
    # entity most recently opened with open_ent(). Returns -1 when there is
    # no such entity, which E_() then passes to safe_get().
    #
    # The previous body returned -1 unconditionally, so E_() could never
    # reach a real entity token.
    if i < 0 or i >= self._e_i:
        return -1
    # open_ent() appends at _e_i and then increments it, so the newest
    # entity sits at _e_i-1 and older ones below it.
    return self._ents[self._e_i - (i+1)].start
cdef int L(self, int i, int idx) nogil:
if idx < 1:
return -1
@ -94,7 +102,10 @@ cdef class StateClass:
return self.safe_get(self.B(i))
cdef const TokenC* H_(self, int i) nogil:
    # Token pointer for the index returned by H(i), resolved via safe_get().
    # A stale `return self.safe_get(self.B(i))` used to sit in front of this
    # return, which handed back the B(i) token and left the H lookup
    # unreachable.
    return self.safe_get(self.H(i))
cdef const TokenC* E_(self, int i) nogil:
    # Token pointer for the index that E(i) reports, resolved via safe_get().
    cdef int token_i = self.E(i)
    return self.safe_get(token_i)
cdef const TokenC* L_(self, int i, int idx) nogil:
    # Token pointer for the index that L(i, idx) reports, resolved via
    # safe_get().
    cdef int token_i = self.L(i, idx)
    return self.safe_get(token_i)
@ -129,6 +140,11 @@ cdef class StateClass:
# Always reports False: no inspection of the stack is performed here.
# NOTE(review): looks like an unimplemented stub -- confirm before relying on it.
cdef bint stack_is_connected(self) nogil:
return False
cdef bint entity_is_open(self) nogil:
    # True when the most recent entity has been opened but not yet closed.
    #
    # open_ent() marks a fresh entity with end = 0 and close_ent() sets
    # end to B(0)+1, so "open" means end == 0. The previous `!= 0` test
    # reported the opposite.
    if self._e_i < 1:
        # No entity has been recorded at all.
        return False
    return self._ents[self._e_i-1].end == 0
# Number of entries currently on the stack, i.e. the _s_i counter.
cdef int stack_depth(self) nogil:
return self._s_i
@ -164,6 +180,21 @@ cdef class StateClass:
else:
self._sent[head].l_kids &= ~(1 << dist)
cdef void open_ent(self, int label) nogil:
    # Append a new entity record that starts at the current B(0) index,
    # carries `label`, and has end = 0; then advance the entity counter.
    cdef Entity* slot = &self._ents[self._e_i]
    slot.start = self.B(0)
    slot.label = label
    slot.end = 0
    self._e_i += 1
cdef void close_ent(self) nogil:
    # Close the most recent entity so that it ends just after B(0), and set
    # ent_iob = 1 on the B(0) token.
    #
    # open_ent() increments _e_i after filling a slot, so the entity being
    # closed is at _e_i-1. The previous code wrote to _ents[_e_i], the next
    # unused slot, leaving the open entity's end at 0.
    if self._e_i >= 1:
        self._ents[self._e_i-1].end = self.B(0)+1
    self._sent[self.B(0)].ent_iob = 1
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil:
    # Write ent_iob and ent_type onto token i; indices outside
    # [0, self.length) are ignored.
    if i < 0 or i >= self.length:
        return
    self._sent[i].ent_iob = ent_iob
    self._sent[i].ent_type = ent_type
cdef void set_sent_end(self, int i) nogil:
    # Flag token i as a sentence end. Only indices strictly between 0 and
    # self.length are accepted; index 0 and out-of-range values are ignored.
    if i <= 0 or i >= self.length:
        return
    self._sent[i].sent_end = True
@ -172,8 +203,10 @@ cdef class StateClass:
# Copy every per-token array and all three cursors from src.
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
memcpy(self._stack, src._stack, self.length * sizeof(int))
memcpy(self._buffer, src._buffer, self.length * sizeof(int))
# _ents is allocated as `length` Entity structs, so the copy must be sized
# by Entity; sizeof(int) copied only part of the array.
memcpy(self._ents, src._ents, self.length * sizeof(Entity))
self._b_i = src._b_i
self._s_i = src._s_i
self._e_i = src._e_i
# From https://en.wikipedia.org/wiki/Hamming_weight

View File

@ -7,6 +7,8 @@ from ..gold cimport GoldParse
from ..gold cimport GoldParseC
from ..strings cimport StringStore
from .stateclass cimport StateClass
cdef struct Transition:
int clas
@ -15,7 +17,7 @@ cdef struct Transition:
weight_t score
bint (*is_valid)(const State* state, int label) except -1
bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1
int (*do)(State* state, int label) except -1
@ -43,11 +45,11 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *
cdef int set_valid(self, bint* output, const State* state) except -1
cdef int set_valid(self, bint* output, StateClass state) except -1
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
cdef Transition best_gold(self, const weight_t* scores, const State* state,
cdef Transition best_gold(self, const weight_t* scores, State* state,
GoldParse gold) except *

View File

@ -44,10 +44,10 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *:
raise NotImplementedError
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
raise NotImplementedError
cdef int set_valid(self, bint* output, const State* state) except -1:
cdef int set_valid(self, bint* output, StateClass state) except -1:
raise NotImplementedError
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
@ -63,9 +63,10 @@ cdef class TransitionSystem:
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
if scores[i] > score and cost == 0:
best = self.c[i]
score = scores[i]
if self.c[i].is_valid(stcls, self.c[i].label):
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
if scores[i] > score and cost == 0:
best = self.c[i]
score = scores[i]
assert score > MIN_SCORE
return best