* Move StateClass into the interface for is_valid

This commit is contained in:
Matthew Honnibal 2015-06-09 23:23:28 +02:00
parent 09617a4638
commit e0cf61f591
10 changed files with 132 additions and 340 deletions

View File

@ -1,10 +1,11 @@
from thinc.typedefs cimport atom_t from thinc.typedefs cimport atom_t
from ._state cimport State from ._state cimport State
from .stateclass cimport StateClass
cdef int fill_context(atom_t* context, State* state) except -1 cdef int fill_context(atom_t* context, State* state) except -1
cdef int _new_fill_context(atom_t* context, State* state) except -1 cdef int _new_fill_context(atom_t* context, StateClass state) except -1
# Context elements # Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order # Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -65,13 +65,11 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[10] = token.ent_iob context[10] = token.ent_iob
context[11] = token.ent_type context[11] = token.ent_type
cdef int _new_fill_context(atom_t* ctxt, State* state) except -1: cdef int _new_fill_context(atom_t* ctxt, StateClass st) except -1:
# Take care to fill every element of context! # Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that # We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact # make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix... # tends to be dramatic, so we get an obvious regression to fix...
cdef StateClass st = StateClass(state.sent_len)
st.from_struct(state)
fill_token(&ctxt[S2w], st.S_(2)) fill_token(&ctxt[S2w], st.S_(2))
fill_token(&ctxt[S1w], st.S_(1)) fill_token(&ctxt[S1w], st.S_(1))
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1)) fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
@ -89,8 +87,8 @@ cdef int _new_fill_context(atom_t* ctxt, State* state) except -1:
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2)) fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
# TODO # TODO
fill_token(&ctxt[E0w], get_e0(state)) fill_token(&ctxt[E0w], st.E_(0))
fill_token(&ctxt[E1w], get_e1(state)) fill_token(&ctxt[E1w], st.E_(1))
if st.stack_depth() >= 1 and not st.eol(): if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min(st.S(0) - st.B(0), 5) # TODO: This is backwards!! ctxt[dist] = min(st.S(0) - st.B(0), 5) # TODO: This is backwards!!

View File

@ -4,6 +4,8 @@ from thinc.typedefs cimport weight_t
from ._state cimport State from ._state cimport State
from .stateclass cimport StateClass
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):

View File

@ -40,9 +40,6 @@ cdef enum:
BREAK BREAK
CONSTITUENT
ADJUST
N_MOVES N_MOVES
@ -52,8 +49,6 @@ MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L' MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R' MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B' MOVE_NAMES[BREAK] = 'B'
MOVE_NAMES[CONSTITUENT] = 'C'
MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
@ -69,15 +64,8 @@ cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -
cost += 1 cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)): if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1 cost += 1
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0 cost += Break.is_valid(stcls, -1) and Break.move_cost(st, gold) == 0
return cost return cost
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
#cost += head_in_stack(st, target, gold.heads)
#cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
#cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
#return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1: cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
@ -92,10 +80,6 @@ cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1
if gold.heads[B_i] == B_i or gold.heads[B_i] < target: if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break break
return cost return cost
#cost += children_in_buffer(st, target, gold.heads)
#cost += head_in_buffer(st, target, gold.heads)
#return cost
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1: cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
cdef StateClass stcls = StateClass(st.sent_len) cdef StateClass stcls = StateClass(st.sent_len)
@ -108,14 +92,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
return 1 return 1
else: else:
return 0 return 0
#if arc_is_gold(gold, head, child):
# return 0
#elif (child + st.sent[child].head) == gold.heads[child]:
# return 1
#elif gold.heads[child] >= st.i:
# return 1
#else:
# return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
@ -146,11 +122,7 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) except -1:
cdef class Shift: cdef class Shift:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
return not st.eol() return not st.eol()
@staticmethod @staticmethod
@ -162,8 +134,6 @@ cdef class Shift:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Shift.is_valid(s, label):
return 9000
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label) return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
@staticmethod @staticmethod
@ -177,19 +147,12 @@ cdef class Shift:
cdef class Reduce: cdef class Reduce:
@staticmethod @staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC: if NON_MONOTONIC:
return st.stack_depth() >= 2 #and not missing_brackets(s) return st.stack_depth() >= 2 #and not missing_brackets(s)
else: else:
return st.stack_depth() >= 2 and st.has_head(st.S(0)) return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
return s.stack_len >= 2 #and not missing_brackets(s)
else:
return s.stack_len >= 2 and has_head(get_s0(s))
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(State* state, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2: if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
@ -198,8 +161,6 @@ cdef class Reduce:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Reduce.is_valid(s, label):
return 9000
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod @staticmethod
@ -216,19 +177,12 @@ cdef class Reduce:
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC: if NON_MONOTONIC:
return st.stack_depth() >= 1 #and not missing_brackets(s) return st.stack_depth() >= 1 #and not missing_brackets(s)
else: else:
return st.stack_depth() >= 1 and not st.has_head(st.S(0)) return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
return s.stack_len >= 1 #and not missing_brackets(s)
else:
return s.stack_len >= 1 and not has_head(get_s0(s))
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(State* state, int label) except -1:
# Interpret left-arcs from EOL as attachment to root # Interpret left-arcs from EOL as attachment to root
@ -240,15 +194,11 @@ cdef class LeftArc:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not LeftArc.is_valid(s, label):
return 9000
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not LeftArc.is_valid(s, -1): if arc_is_gold(gold, s.i, s.stack[0]):
return 9000
elif arc_is_gold(gold, s.i, s.stack[0]):
return 0 return 0
else: else:
return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0]) return pop_cost(s, gold, s.stack[0]) + arc_cost(s, gold, s.i, s.stack[0])
@ -260,11 +210,7 @@ cdef class LeftArc:
cdef class RightArc: cdef class RightArc:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return s.stack_len >= 1 and not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
return st.stack_depth() >= 1 and not st.eol() return st.stack_depth() >= 1 and not st.eol()
@staticmethod @staticmethod
@ -274,8 +220,6 @@ cdef class RightArc:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not RightArc.is_valid(s, label):
return 9000
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod @staticmethod
@ -292,7 +236,7 @@ cdef class RightArc:
cdef class Break: cdef class Break:
@staticmethod @staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
cdef int i cdef int i
if not USE_BREAK: if not USE_BREAK:
return False return False
@ -317,32 +261,6 @@ cdef class Break:
# TODO: Constituency constraints # TODO: Constituency constraints
return True return True
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
cdef int i
if not USE_BREAK:
return False
elif at_eol(s):
return False
elif s.stack_len < 1:
return False
elif NON_MONOTONIC:
return True
else:
# In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
# we prefer to relax this constraint. This is helpful in parsing whole
# documents, because then we don't get stuck with words on the stack.
seen_headless = False
for i in range(s.stack_len):
if s.sent[s.stack[-i]].head == 0:
if seen_headless:
return False
else:
seen_headless = True
# TODO: Constituency constraints
return True
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(State* state, int label) except -1:
state.sent[state.i-1].sent_end = True state.sent[state.i-1].sent_end = True
@ -354,9 +272,6 @@ cdef class Break:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Break.is_valid(s, label):
return 9000
else:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label) return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod @staticmethod
@ -374,163 +289,11 @@ cdef class Break:
return 0 return 0
cdef class Constituent:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if s.stack_len < 1:
return False
return False
#else:
# # If all stack elements are popped, can't constituent
# for i in range(s.ctnts.stack_len):
# if not s.ctnts.is_popped[-i]:
# return True
# else:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* bracket = new_bracket(state.ctnts)
#bracket.parent = NULL
#bracket.label = self.label
#bracket.head = get_s0(state)
#bracket.length = 0
#attach(bracket, state.ctnts.stack)
# Attach rightward children. They're in the brackets array somewhere
# between here and B0.
#cdef Constituent* node
#cdef const TokenC* node_gov
#for i in range(1, bracket - state.ctnts.stack):
# node = bracket - i
# node_gov = node.head + node.head.head
# if node_gov == bracket.head:
# attach(bracket, node)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Constituent.is_valid(s, label):
return 9000
raise Exception("Constituent move should be disabled currently")
# The gold standard is indexed by end, then by start, then a set of labels
#brackets = gold.brackets(get_s0(s).r_edge, {})
#if not brackets:
# return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
# Index the current brackets in the state
#existing = set()
#for i in range(s.ctnt_len):
# if ctnt.end == s.r_edge and ctnt.label == self.label:
# existing.add(ctnt.start)
#cdef int loss = 2
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets and child.l_edge not in existing:
# if self.label in brackets[child.l_edge]
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Constituent.is_valid(s, -1):
return 9000
raise Exception("Constituent move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
cdef class Adjust:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return False
#if s.ctnts.stack_len < 2:
# return False
#cdef const Constituent* b1 = s.ctnts.stack[-1]
#cdef const Constituent* b0 = s.ctnts.stack[0]
#if (b1.head + b1.head.head) != b0.head:
# return False
#elif b0.head >= b1.head:
# return False
#elif b0 >= b1:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* b0 = state.ctnts.stack[0]
#cdef Constituent* b1 = state.ctnts.stack[1]
#assert (b1.head + b1.head.head) == b0.head
#assert b0.head < b1.head
#assert b0 < b1
#attach(b0, b1)
## Pop B1 from stack, but keep B0 on top
#state.ctnts.stack -= 1
#state.ctnts.stack[0] = b0
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Adjust.is_valid(s, label):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Adjust.is_valid(s, -1):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
# The gold standard is indexed by end, then by start, then a set of labels
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
# Case 1: There are 0 brackets ending at this word.
# --> Cost is sunk, but must allow brackets to begin
#if not gold_starts:
# return 0
# Is the top bracket correct?
#gold_labels = gold_starts.get(s.ctnt.start, set())
# TODO: Case where we have a unary rule
# TODO: Case where two brackets end on this word, with top bracket starting
# before
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
#cdef int i
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets:
# if self.label in brackets[child.l_edge]:
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
@classmethod @classmethod
def get_labels(cls, gold_parses): def get_labels(cls, gold_parses):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}, LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
CONSTITUENT: {}, ADJUST: {'': True}}
for raw_text, sents in gold_parses: for raw_text, sents in gold_parses:
for (ids, words, tags, heads, labels, iob), ctnts in sents: for (ids, words, tags, heads, labels, iob), ctnts in sents:
for child, head, label in zip(ids, heads, labels): for child, head, label in zip(ids, heads, labels):
@ -539,8 +302,6 @@ cdef class ArcEager(TransitionSystem):
move_labels[RIGHT][label] = True move_labels[RIGHT][label] = True
elif head > child: elif head > child:
move_labels[LEFT][label] = True move_labels[LEFT][label] = True
for start, end, label in ctnts:
move_labels[CONSTITUENT][label] = True
return move_labels return move_labels
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1:
@ -604,14 +365,6 @@ cdef class ArcEager(TransitionSystem):
t.is_valid = Break.is_valid t.is_valid = Break.is_valid
t.do = Break.transition t.do = Break.transition
t.get_cost = Break.cost t.get_cost = Break.cost
elif move == CONSTITUENT:
t.is_valid = Constituent.is_valid
t.do = Constituent.transition
t.get_cost = Constituent.cost
elif move == ADJUST:
t.is_valid = Adjust.is_valid
t.do = Adjust.transition
t.get_cost = Adjust.cost
else: else:
raise Exception(move) raise Exception(move)
return t return t
@ -625,18 +378,13 @@ cdef class ArcEager(TransitionSystem):
if state.sent[i].head == 0 and state.sent[i].dep == 0: if state.sent[i].head == 0 and state.sent[i].dep == 0:
state.sent[i].dep = root_label state.sent[i].dep = root_label
cdef int set_valid(self, bint* output, const State* state) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
raise Exception
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1) is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1) is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1) is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1) is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1) is_valid[BREAK] = Break.is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent.is_valid(state, -1)
is_valid[ADJUST] = False # Adjust.is_valid(state, -1)
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
@ -653,38 +401,36 @@ cdef class ArcEager(TransitionSystem):
move_cost_funcs[LEFT] = LeftArc.move_cost move_cost_funcs[LEFT] = LeftArc.move_cost
move_cost_funcs[RIGHT] = RightArc.move_cost move_cost_funcs[RIGHT] = RightArc.move_cost
move_cost_funcs[BREAK] = Break.move_cost move_cost_funcs[BREAK] = Break.move_cost
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
move_cost_funcs[ADJUST] = Adjust.move_cost
label_cost_funcs[SHIFT] = Shift.label_cost label_cost_funcs[SHIFT] = Shift.label_cost
label_cost_funcs[REDUCE] = Reduce.label_cost label_cost_funcs[REDUCE] = Reduce.label_cost
label_cost_funcs[LEFT] = LeftArc.label_cost label_cost_funcs[LEFT] = LeftArc.label_cost
label_cost_funcs[RIGHT] = RightArc.label_cost label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost label_cost_funcs[BREAK] = Break.label_cost
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
label_cost_funcs[ADJUST] = Adjust.label_cost
cdef int* labels = gold.c.labels cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads cdef int* heads = gold.c.heads
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
self.set_valid(self._is_valid, stcls)
for i in range(self.n_moves): for i in range(self.n_moves):
if not self._is_valid[i]:
output[i] = 9000
else:
move = self.c[i].move move = self.c[i].move
label = self.c[i].label label = self.c[i].label
if move_costs[move] == -1: if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](s, &gold.c) move_costs[move] = move_cost_funcs[move](s, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
assert s is not NULL
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1) is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1) is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1) is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1) is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1) is_valid[BREAK] = Break.is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
cdef Transition best cdef Transition best
cdef weight_t score = MIN_SCORE cdef weight_t score = MIN_SCORE
cdef int i cdef int i
@ -703,5 +449,3 @@ cdef class ArcEager(TransitionSystem):
best.label = self.c[i].label best.label = self.c[i].label
score = scores[i] score = scores[i]
return best return best

View File

@ -11,6 +11,8 @@ from thinc.typedefs cimport weight_t
from ..gold cimport GoldParseC from ..gold cimport GoldParseC
from ..gold cimport GoldParse from ..gold cimport GoldParse
from .stateclass cimport StateClass
cdef enum: cdef enum:
MISSING MISSING
@ -132,14 +134,14 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move) raise Exception(move)
return t return t
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef int best = -1 cdef int best = -1
cdef weight_t score = -90000 cdef weight_t score = -90000
cdef const Transition* m cdef const Transition* m
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
m = &self.c[i] m = &self.c[i]
if m.is_valid(s, m.label) and scores[i] > score: if m.is_valid(stcls, m.label) and scores[i] > score:
best = i best = i
score = scores[i] score = scores[i]
assert best >= 0 assert best >= 0
@ -147,16 +149,16 @@ cdef class BiluoPushDown(TransitionSystem):
t.score = score t.score = score
return t return t
cdef int set_valid(self, bint* output, const State* s) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
m = &self.c[i] m = &self.c[i]
output[i] = m.is_valid(s, m.label) output[i] = m.is_valid(stcls, m.label)
cdef class Missing: cdef class Missing:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return False return False
@staticmethod @staticmethod
@ -170,8 +172,8 @@ cdef class Missing:
cdef class Begin: cdef class Begin:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return label != 0 and not entity_is_open(s) return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(State* s, int label) except -1:
@ -186,8 +188,6 @@ cdef class Begin:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Begin.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
@ -203,10 +203,11 @@ cdef class Begin:
# B, Gold U --> False (P) # B, Gold U --> False (P)
return 1 return 1
cdef class In: cdef class In:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return entity_is_open(s) and label != 0 and s.ent.label == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(State* s, int label) except -1:
@ -216,8 +217,6 @@ cdef class In:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not In.is_valid(s, label):
return 9000
move = IN move = IN
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
@ -245,11 +244,10 @@ cdef class In:
return 1 return 1
cdef class Last: cdef class Last:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return entity_is_open(s) and label != 0 and s.ent.label == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(State* s, int label) except -1:
@ -260,8 +258,6 @@ cdef class Last:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Last.is_valid(s, label):
return 9000
move = LAST move = LAST
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
@ -290,8 +286,8 @@ cdef class Last:
cdef class Unit: cdef class Unit:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return label != 0 and not entity_is_open(s) return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(State* s, int label) except -1:
@ -306,8 +302,6 @@ cdef class Unit:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Unit.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
@ -326,8 +320,8 @@ cdef class Unit:
cdef class Out: cdef class Out:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) except -1:
return not entity_is_open(s) return not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(State* s, int label) except -1:
@ -336,9 +330,6 @@ cdef class Out:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Out.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label

View File

@ -40,8 +40,9 @@ from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport _new_fill_context as fill_context from ._parse_features cimport _new_fill_context
#from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from .stateclass cimport StateClass
DEBUG = False DEBUG = False
@ -104,11 +105,13 @@ cdef class Parser:
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
cdef Transition guess cdef Transition guess
while not is_final(state): while not is_final(state):
fill_context(context, state) stcls.from_struct(state)
_new_fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, stcls)
guess.do(state, guess.label) guess.do(state, guess.label)
self.moves.finalize_state(state) self.moves.finalize_state(state)
tokens.set_parse(state.sent) tokens.set_parse(state.sent)
@ -133,12 +136,14 @@ cdef class Parser:
cdef const weight_t* scores cdef const weight_t* scores
cdef Transition guess cdef Transition guess
cdef Transition best cdef Transition best
cdef StateClass stcls = StateClass(state.sent_len)
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
loss = 0 loss = 0
while not is_final(state): while not is_final(state):
fill_context(context, state) stcls.from_struct(state)
_new_fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, state, gold) best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(state, &gold.c, guess.label) cost = guess.get_cost(state, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost) self.model.update(context, guess.clas, best.clas, cost)
@ -174,12 +179,14 @@ cdef class Parser:
cdef int i, j, cost cdef int i, j, cost
cdef bint is_valid cdef bint is_valid
cdef const Transition* move cdef const Transition* move
cdef StateClass stcls = StateClass(gold.length)
for i in range(beam.size): for i in range(beam.size):
state = <State*>beam.at(i) state = <State*>beam.at(i)
stcls.from_struct(state)
if not is_final(state): if not is_final(state):
fill_context(context, state) fill_context(context, state)
self.model.set_scores(beam.scores[i], context) self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], state) self.moves.set_valid(beam.is_valid[i], stcls)
if gold is not None: if gold is not None:
for i in range(beam.size): for i in range(beam.size):

View File

@ -2,7 +2,7 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..structs cimport TokenC from ..structs cimport TokenC, Entity
from ._state cimport State from ._state cimport State
@ -14,10 +14,12 @@ cdef class StateClass:
cdef int* _stack cdef int* _stack
cdef int* _buffer cdef int* _buffer
cdef TokenC* _sent cdef TokenC* _sent
cdef Entity* _ents
cdef TokenC _empty_token cdef TokenC _empty_token
cdef int length cdef int length
cdef int _s_i cdef int _s_i
cdef int _b_i cdef int _b_i
cdef int _e_i
cdef int from_struct(self, const State* state) except -1 cdef int from_struct(self, const State* state) except -1
@ -25,6 +27,7 @@ cdef class StateClass:
cdef int B(self, int i) nogil cdef int B(self, int i) nogil
cdef int H(self, int i) nogil cdef int H(self, int i) nogil
cdef int E(self, int i) nogil
cdef int L(self, int i, int idx) nogil cdef int L(self, int i, int idx) nogil
cdef int R(self, int i, int idx) nogil cdef int R(self, int i, int idx) nogil
@ -33,6 +36,7 @@ cdef class StateClass:
cdef const TokenC* B_(self, int i) nogil cdef const TokenC* B_(self, int i) nogil
cdef const TokenC* H_(self, int i) nogil cdef const TokenC* H_(self, int i) nogil
cdef const TokenC* E_(self, int i) nogil
cdef const TokenC* L_(self, int i, int idx) nogil cdef const TokenC* L_(self, int i, int idx) nogil
cdef const TokenC* R_(self, int i, int idx) nogil cdef const TokenC* R_(self, int i, int idx) nogil
@ -41,12 +45,15 @@ cdef class StateClass:
cdef bint empty(self) nogil cdef bint empty(self) nogil
cdef bint entity_is_open(self) nogil
cdef bint eol(self) nogil cdef bint eol(self) nogil
cdef bint is_final(self) nogil cdef bint is_final(self) nogil
cdef bint has_head(self, int i) nogil cdef bint has_head(self, int i) nogil
cdef int n_L(self, int i) nogil cdef int n_L(self, int i) nogil
cdef int n_R(self, int i) nogil cdef int n_R(self, int i) nogil
@ -65,6 +72,12 @@ cdef class StateClass:
cdef void del_arc(self, int head, int child) nogil cdef void del_arc(self, int head, int child) nogil
cdef void open_ent(self, int label) nogil
cdef void close_ent(self) nogil
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil
cdef void set_sent_end(self, int i) nogil cdef void set_sent_end(self, int i) nogil
cdef void clone(self, StateClass src) nogil cdef void clone(self, StateClass src) nogil

View File

@ -9,10 +9,12 @@ cdef class StateClass:
self._buffer = <int*>mem.alloc(length, sizeof(int)) self._buffer = <int*>mem.alloc(length, sizeof(int))
self._stack = <int*>mem.alloc(length, sizeof(int)) self._stack = <int*>mem.alloc(length, sizeof(int))
self._sent = <TokenC*>mem.alloc(length, sizeof(TokenC)) self._sent = <TokenC*>mem.alloc(length, sizeof(TokenC))
self._ents = <Entity*>mem.alloc(length, sizeof(Entity))
self.mem = mem self.mem = mem
self.length = length self.length = length
self._s_i = 0 self._s_i = 0
self._b_i = 0 self._b_i = 0
self._e_i = 0
cdef int i cdef int i
for i in range(length): for i in range(length):
self._buffer[i] = i self._buffer[i] = i
@ -21,10 +23,13 @@ cdef class StateClass:
cdef int from_struct(self, const State* state) except -1:
    # Load this object's cursors, tokens, stack and entities from a C State.
    #
    # _s_i, _b_i and _e_i are taken from state.stack_len, state.i and
    # state.ents_len; self.length tokens are copied from state.sent.
    self._s_i = state.stack_len
    self._b_i = state.i
    self._e_i = state.ents_len
    memcpy(self._sent, state.sent, sizeof(TokenC) * self.length)
    cdef int i
    # state.stack is read at offsets 0, -1, -2, ...; offset 0 lands in the
    # last used slot (_s_i - 1) and each further item one slot lower.
    for i in range(state.stack_len):
        self._stack[self._s_i - (i+1)] = state.stack[-i]
    # Entities are read with the same 0, -1, -2, ... offsets, so they are
    # stored with the same mapping. entity_is_open() and open_ent() treat
    # _ents[_e_i - 1] as the most recent entity; writing state.ent[-i] to
    # _ents[i] would leave the entries in the opposite order.
    # NOTE(review): assumes state.ent[0] is the most recently opened entity,
    # mirroring how state.stack[0] is handled above - confirm against State.
    for i in range(state.ents_len):
        self._ents[self._e_i - (i+1)] = state.ent[-i]
cdef int S(self, int i) nogil: cdef int S(self, int i) nogil:
if i >= self._s_i: if i >= self._s_i:
@ -41,6 +46,9 @@ cdef class StateClass:
return -1 return -1
return self._sent[i].head + i return self._sent[i].head + i
# Entity counterpart of the S/B/H index accessors. Currently a stub: the
# argument i is ignored and -1 is always returned.
cdef int E(self, int i) nogil:
return -1
cdef int L(self, int i, int idx) nogil: cdef int L(self, int i, int idx) nogil:
if idx < 1: if idx < 1:
return -1 return -1
@ -94,7 +102,10 @@ cdef class StateClass:
return self.safe_get(self.B(i)) return self.safe_get(self.B(i))
cdef const TokenC* H_(self, int i) nogil: cdef const TokenC* H_(self, int i) nogil:
return self.safe_get(self.B(i)) return self.safe_get(self.H(i))
cdef const TokenC* E_(self, int i) nogil:
    # Token-pointer form of E(): look up the index with E(i), then let
    # safe_get() turn that index into a TokenC pointer.
    cdef int idx = self.E(i)
    return self.safe_get(idx)
cdef const TokenC* L_(self, int i, int idx) nogil: cdef const TokenC* L_(self, int i, int idx) nogil:
return self.safe_get(self.L(i, idx)) return self.safe_get(self.L(i, idx))
@ -129,6 +140,11 @@ cdef class StateClass:
cdef bint stack_is_connected(self) nogil: cdef bint stack_is_connected(self) nogil:
return False return False
cdef bint entity_is_open(self) nogil:
    # Report whether the most recently opened entity is still unclosed.
    #
    # open_ent() initialises an entity's end to 0, and close_ent() assigns
    # B(0)+1 as the end, so "open" means end == 0. The earlier `!= 0`
    # comparison answered the opposite question.
    if self._e_i < 1:
        # No entity has been opened yet.
        return False
    return self._ents[self._e_i-1].end == 0
cdef int stack_depth(self) nogil: cdef int stack_depth(self) nogil:
return self._s_i return self._s_i
@ -164,6 +180,21 @@ cdef class StateClass:
else: else:
self._sent[head].l_kids &= ~(1 << dist) self._sent[head].l_kids &= ~(1 << dist)
cdef void open_ent(self, int label) nogil:
    # Begin a new entity with the given label in the next free slot of
    # _ents, starting at the index returned by B(0).
    cdef Entity* slot = &self._ents[self._e_i]
    slot.start = self.B(0)
    slot.label = label
    # end starts at 0; close_ent() is what assigns an actual end.
    slot.end = 0
    # Advance the entity cursor past the slot just filled.
    self._e_i += 1
cdef void close_ent(self) nogil:
    # Close the most recently opened entity at the index returned by B(0).
    #
    # open_ent() increments _e_i after filling a slot, so the entity being
    # closed is at _e_i - 1; slot _e_i is the next unused one and writing
    # the end there would leave the open entity untouched.
    if self._e_i < 1:
        # Nothing is open: do not write before the start of _ents.
        return
    self._ents[self._e_i-1].end = self.B(0)+1
    # NOTE(review): B(0) is used unchecked as an index into _sent here,
    # unlike set_ent_tag() which range-checks - confirm it is always valid
    # when this is called.
    self._sent[self.B(0)].ent_iob = 1
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil:
    # Store the entity IOB code and entity type on token i.
    # Indices outside [0, self.length) are ignored.
    if i < 0 or i >= self.length:
        return
    self._sent[i].ent_iob = ent_iob
    self._sent[i].ent_type = ent_type
cdef void set_sent_end(self, int i) nogil: cdef void set_sent_end(self, int i) nogil:
if 0 < i < self.length: if 0 < i < self.length:
self._sent[i].sent_end = True self._sent[i].sent_end = True
@ -172,8 +203,10 @@ cdef class StateClass:
memcpy(self._sent, src._sent, self.length * sizeof(TokenC)) memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
memcpy(self._stack, src._stack, self.length * sizeof(int)) memcpy(self._stack, src._stack, self.length * sizeof(int))
memcpy(self._buffer, src._buffer, self.length * sizeof(int)) memcpy(self._buffer, src._buffer, self.length * sizeof(int))
# _ents is an array of Entity structs (allocated with sizeof(Entity)), so the
# copy must be sized in Entity units; sizeof(int) copies only part of it.
memcpy(self._ents, src._ents, self.length * sizeof(Entity))
self._b_i = src._b_i self._b_i = src._b_i
self._s_i = src._s_i self._s_i = src._s_i
self._e_i = src._e_i
# From https://en.wikipedia.org/wiki/Hamming_weight # From https://en.wikipedia.org/wiki/Hamming_weight

View File

@ -7,6 +7,8 @@ from ..gold cimport GoldParse
from ..gold cimport GoldParseC from ..gold cimport GoldParseC
from ..strings cimport StringStore from ..strings cimport StringStore
from .stateclass cimport StateClass
cdef struct Transition: cdef struct Transition:
int clas int clas
@ -15,7 +17,7 @@ cdef struct Transition:
weight_t score weight_t score
bint (*is_valid)(const State* state, int label) except -1 bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1 int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1
int (*do)(State* state, int label) except -1 int (*do)(State* state, int label) except -1
@ -43,11 +45,11 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except * cdef Transition init_transition(self, int clas, int move, int label) except *
cdef int set_valid(self, bint* output, const State* state) except -1 cdef int set_valid(self, bint* output, StateClass state) except -1
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1 cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1
cdef Transition best_valid(self, const weight_t* scores, const State* state) except * cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
cdef Transition best_gold(self, const weight_t* scores, const State* state, cdef Transition best_gold(self, const weight_t* scores, State* state,
GoldParse gold) except * GoldParse gold) except *

View File

@ -44,10 +44,10 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *: cdef Transition init_transition(self, int clas, int move, int label) except *:
raise NotImplementedError raise NotImplementedError
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
raise NotImplementedError raise NotImplementedError
cdef int set_valid(self, bint* output, const State* state) except -1: cdef int set_valid(self, bint* output, StateClass state) except -1:
raise NotImplementedError raise NotImplementedError
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1: cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
@ -63,6 +63,7 @@ cdef class TransitionSystem:
cdef weight_t score = MIN_SCORE cdef weight_t score = MIN_SCORE
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label) cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
if scores[i] > score and cost == 0: if scores[i] > score and cost == 0:
best = self.c[i] best = self.c[i]