* Prepare to switch to using state class, instead of state struct

This commit is contained in:
Matthew Honnibal 2015-06-09 21:20:14 +02:00
parent 2b9629ed62
commit 0895d454fb
10 changed files with 245 additions and 141 deletions

View File

@ -4,6 +4,7 @@ from ._state cimport State
cdef int fill_context(atom_t* context, State* state) except -1
cdef int _new_fill_context(atom_t* context, State* state) except -1
# Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -20,6 +20,11 @@ from ._state cimport has_head, get_left, get_right
from ._state cimport count_left_kids, count_right_kids
from .stateclass cimport StateClass
from cymem.cymem cimport Pool
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
if token is NULL:
context[0] = 0
@ -60,6 +65,53 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[10] = token.ent_iob
context[11] = token.ent_type
cdef int _new_fill_context(atom_t* ctxt, State* state) except -1:
# Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix...
cdef StateClass st = StateClass(state.sent_len)
st.from_struct(state)
fill_token(&ctxt[S2w], st.S_(2))
fill_token(&ctxt[S1w], st.S_(1))
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
fill_token(&ctxt[S0lw], st.L_(st.S(0), 1))
fill_token(&ctxt[S0l2w], st.L_(st.S(0), 2))
fill_token(&ctxt[S0w], st.S_(0))
fill_token(&ctxt[S0r2w], st.R_(st.S(0), 2))
fill_token(&ctxt[S0rw], st.R_(st.S(0), 1))
fill_token(&ctxt[N0lw], st.L_(st.B(0), 1))
fill_token(&ctxt[N0l2w], st.L_(st.B(0), 2))
fill_token(&ctxt[N0w], st.B_(0))
fill_token(&ctxt[N1w], st.B_(1))
fill_token(&ctxt[N2w], st.B_(2))
fill_token(&ctxt[P1w], st.safe_get(st.B(0)-1))
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
# TODO
fill_token(&ctxt[E0w], get_e0(state))
fill_token(&ctxt[E1w], get_e1(state))
if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min(st.S(0) - st.B(0), 5) # TODO: This is backwards!!
else:
ctxt[dist] = 0
ctxt[N0lv] = min(st.n_L(st.B(0)), 5)
ctxt[S0lv] = min(st.n_L(st.S(0)), 5)
ctxt[S0rv] = min(st.n_R(st.S(0)), 5)
ctxt[S1lv] = min(st.n_L(st.S(1)), 5)
ctxt[S1rv] = min(st.n_R(st.S(1)), 5)
ctxt[S0_has_head] = 0
ctxt[S1_has_head] = 0
ctxt[S2_has_head] = 0
if st.stack_depth() >= 1:
ctxt[S0_has_head] = st.has_head(st.S(0)) + 1
if st.stack_depth() >= 2:
ctxt[S1_has_head] = st.has_head(st.S(1)) + 1
if st.stack_depth() >= 3:
ctxt[S2_has_head] = st.has_head(st.S(2)) + 1
cdef int fill_context(atom_t* context, State* state) except -1:
# Take care to fill every element of context!

View File

@ -115,29 +115,33 @@ cdef bint has_head(const TokenC* t) nogil:
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
return _new_get_left(s, head, idx)
#cdef uint32_t kids = head.l_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head - offset
#if child >= s.sent:
# return child
##else:
# return NULL
"""
cdef uint32_t kids = head.l_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head - offset
if child >= s.sent:
return child
else:
return NULL
"""
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
return _new_get_right(s, head, idx)
#cdef uint32_t kids = head.r_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head + offset
#if child < (s.sent + s.sent_len):
# return child
#else:
# return NULL
"""
cdef uint32_t kids = head.r_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head + offset
if child < (s.sent + s.sent_len):
return child
else:
return NULL
"""
cdef int count_left_kids(const TokenC* head) nogil:
return _popcount(head.l_kids)

View File

@ -6,6 +6,5 @@ from thinc.typedefs cimport weight_t
from ._state cimport State
from .transition_system cimport TransitionSystem, Transition
cdef class ArcEager(TransitionSystem):
pass

View File

@ -22,7 +22,7 @@ from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from cymem.cymem cimport Pool
from ..stateclass cimport StateClass
from .stateclass cimport StateClass
DEF NON_MONOTONIC = True
@ -59,32 +59,63 @@ MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int cost = 0
cost += head_in_stack(st, target, gold.heads)
cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
cdef int i, S_i
for i in range(stcls.stack_depth()):
S_i = stcls.S(i)
if gold.heads[target] == S_i:
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
return cost
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
#cost += head_in_stack(st, target, gold.heads)
#cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
#cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
#return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int cost = 0
cost += children_in_buffer(st, target, gold.heads)
cost += head_in_buffer(st, target, gold.heads)
cdef int i, B_i
for i in range(stcls.buffer_length()):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
return cost
#cost += children_in_buffer(st, target, gold.heads)
#cost += head_in_buffer(st, target, gold.heads)
#return cost
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
if arc_is_gold(gold, head, child):
return 0
elif (child + st.sent[child].head) == gold.heads[child]:
elif stcls.H(child) == gold.heads[child]:
return 1
elif gold.heads[child] >= st.i:
elif gold.heads[child] >= stcls.B(0):
return 1
else:
return 0
#if arc_is_gold(gold, head, child):
# return 0
#elif (child + st.sent[child].head) == gold.heads[child]:
# return 1
#elif gold.heads[child] >= st.i:
# return 1
#else:
# return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
@ -122,7 +153,6 @@ cdef class Shift:
cdef bint _new_is_valid(StateClass st, int label) except -1:
return not st.eol()
@staticmethod
cdef int transition(State* state, int label) except -1:
# Set the dep label, in case we need it after we reduce
@ -596,14 +626,17 @@ cdef class ArcEager(TransitionSystem):
state.sent[i].dep = root_label
cdef int set_valid(self, bint* output, const State* state) except -1:
raise Exception
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(state, -1)
is_valid[REDUCE] = Reduce.is_valid(state, -1)
is_valid[LEFT] = LeftArc.is_valid(state, -1)
is_valid[RIGHT] = RightArc.is_valid(state, -1)
is_valid[BREAK] = Break.is_valid(state, -1)
is_valid[CONSTITUENT] = Constituent.is_valid(state, -1)
is_valid[ADJUST] = Adjust.is_valid(state, -1)
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent.is_valid(state, -1)
is_valid[ADJUST] = False # Adjust.is_valid(state, -1)
cdef int i
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
@ -641,10 +674,10 @@ cdef class ArcEager(TransitionSystem):
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Pool mem = Pool()
cdef StateClass stcls = StateClass.from_struct(mem, s)
assert s is not NULL
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef bint[N_MOVES] is_valid
#is_valid[SHIFT] = Shift.is_valid(s, -1)
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)

View File

@ -1,4 +1,5 @@
# cython: profile=True
# cython: experimental_cpp_class_def=True
"""
MALT-style dependency parser
"""
@ -38,7 +39,9 @@ from ._state cimport State, new_state, copy_state, is_final, push_stack, get_lef
from ..gold cimport GoldParse
from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport _new_fill_context as fill_context
#from ._parse_features cimport fill_context
DEBUG = False

View File

@ -2,14 +2,11 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
from structs cimport TokenC
from ..structs cimport TokenC
from .syntax._state cimport State
from ._state cimport State
from .vocab cimport EMPTY_LEXEME
cdef TokenC EMPTY_TOKEN
from ..vocab cimport EMPTY_LEXEME
cdef class StateClass:
@ -17,45 +14,13 @@ cdef class StateClass:
cdef int* _stack
cdef int* _buffer
cdef TokenC* _sent
cdef TokenC _empty_token
cdef int length
cdef int _s_i
cdef int _b_i
@staticmethod
cdef inline StateClass init(const TokenC* sent, int length):
cdef StateClass self = StateClass(length)
memcpy(self._sent, sent, sizeof(TokenC*) * length)
return self
cdef int from_struct(self, const State* state) except -1
@staticmethod
cdef inline StateClass from_struct(Pool mem, const State* state):
cdef StateClass self = StateClass.init(state.sent, state.sent_len)
memcpy(self._stack, state.stack - state.stack_len, sizeof(int) * state.stack_len)
self._s_i = state.stack_len - 1
self._b_i = state.i
return self
cdef inline const TokenC* S_(self, int i) nogil:
return self.safe_get(self.S(i))
cdef inline const TokenC* B_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef inline const TokenC* H_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef inline const TokenC* L_(self, int i, int idx) nogil:
return self.safe_get(self.L(i, idx))
cdef inline const TokenC* R_(self, int i, int idx) nogil:
return self.safe_get(self.R(i, idx))
cdef inline const TokenC* safe_get(self, int i) nogil:
if 0 >= i >= self.length:
return &EMPTY_TOKEN
else:
return self._sent
cdef int S(self, int i) nogil
cdef int B(self, int i) nogil
@ -64,6 +29,16 @@ cdef class StateClass:
cdef int L(self, int i, int idx) nogil
cdef int R(self, int i, int idx) nogil
cdef const TokenC* S_(self, int i) nogil
cdef const TokenC* B_(self, int i) nogil
cdef const TokenC* H_(self, int i) nogil
cdef const TokenC* L_(self, int i, int idx) nogil
cdef const TokenC* R_(self, int i, int idx) nogil
cdef const TokenC* safe_get(self, int i) nogil
cdef bint empty(self) nogil
cdef bint eol(self) nogil
@ -72,6 +47,10 @@ cdef class StateClass:
cdef bint has_head(self, int i) nogil
cdef int n_L(self, int i) nogil
cdef int n_R(self, int i) nogil
cdef bint stack_is_connected(self) nogil
cdef int stack_depth(self) nogil

View File

@ -1,24 +1,33 @@
from libc.string cimport memcpy, memset
from libc.stdint cimport uint32_t
from .vocab cimport EMPTY_LEXEME
memset(&EMPTY_TOKEN, 0, sizeof(TokenC))
EMPTY_TOKEN.lex = &EMPTY_LEXEME
from ..vocab cimport EMPTY_LEXEME
cdef class StateClass:
def __cinit__(self, int length):
self.mem = Pool()
self._stack = <int*>self.mem.alloc(sizeof(int), length)
self._buffer = <int*>self.mem.alloc(sizeof(int), length)
self._sent = <TokenC*>self.mem.alloc(sizeof(TokenC*), length)
self.length = 0
for i in range(self.length):
def __init__(self, int length):
cdef Pool mem = Pool()
self._buffer = <int*>mem.alloc(length, sizeof(int))
self._stack = <int*>mem.alloc(length, sizeof(int))
self._sent = <TokenC*>mem.alloc(length, sizeof(TokenC))
self.mem = mem
self.length = length
self._s_i = 0
self._b_i = 0
cdef int i
for i in range(length):
self._buffer[i] = i
self._empty_token.lex = &EMPTY_LEXEME
cdef int from_struct(self, const State* state) except -1:
self._s_i = state.stack_len
self._b_i = state.i
memcpy(self._sent, state.sent, sizeof(TokenC) * self.length)
cdef int i
for i in range(state.stack_len):
self._stack[self._s_i - (i+1)] = state.stack[-i]
cdef int S(self, int i) nogil:
if self._s_i - (i+1) < 0:
if i >= self._s_i:
return -1
return self._stack[self._s_i - (i+1)]
@ -33,14 +42,71 @@ cdef class StateClass:
return self._sent[i].head + i
cdef int L(self, int i, int idx) nogil:
if 0 <= _popcount(self.safe_get(i).l_kids) <= idx:
if idx < 1:
return -1
return _nth_significant_bit(self.safe_get(i).l_kids, idx)
if i < 0 or i >= self.length:
return -1
cdef const TokenC* target = &self._sent[i]
cdef const TokenC* ptr = self._sent
while ptr < target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head >= 1) and (ptr + ptr.head) < target:
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - self._sent
ptr += 1
else:
ptr += 1
return -1
cdef int R(self, int i, int idx) nogil:
if 0 <= _popcount(self.safe_get(i).r_kids) <= idx:
if idx < 1:
return -1
return _nth_significant_bit(self.safe_get(i).r_kids, idx)
if i < 0 or i >= self.length:
return -1
cdef const TokenC* ptr = self._sent + (self.length - 1)
cdef const TokenC* target = &self._sent[i]
while ptr > target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head < 0) and ((ptr + ptr.head) > target):
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - self._sent
ptr -= 1
else:
ptr -= 1
return -1
cdef const TokenC* S_(self, int i) nogil:
return self.safe_get(self.S(i))
cdef const TokenC* B_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef const TokenC* H_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef const TokenC* L_(self, int i, int idx) nogil:
return self.safe_get(self.L(i, idx))
cdef const TokenC* R_(self, int i, int idx) nogil:
return self.safe_get(self.R(i, idx))
cdef const TokenC* safe_get(self, int i) nogil:
if i < 0 or i >= self.length:
return &self._empty_token
else:
return &self._sent[i]
cdef bint empty(self) nogil:
return self._s_i <= 0
@ -54,6 +120,12 @@ cdef class StateClass:
cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0
cdef int n_L(self, int i) nogil:
return _popcount(self.safe_get(i).l_kids)
cdef int n_R(self, int i) nogil:
return _popcount(self.safe_get(i).r_kids)
cdef bint stack_is_connected(self) nogil:
return False

View File

@ -51,10 +51,3 @@ cdef class TransitionSystem:
cdef Transition best_gold(self, const weight_t* scores, const State* state,
GoldParse gold) except *
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# cdef Pool mem
# cdef TransitionSystem system
# cdef State* _state

View File

@ -3,6 +3,8 @@ from ._state cimport State
from ..structs cimport TokenC
from thinc.typedefs cimport weight_t
from .stateclass cimport StateClass
cdef weight_t MIN_SCORE = -90000
@ -55,6 +57,8 @@ cdef class TransitionSystem:
cdef Transition best_gold(self, const weight_t* scores, const State* s,
GoldParse gold) except *:
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i
@ -65,39 +69,3 @@ cdef class TransitionSystem:
score = scores[i]
assert score > MIN_SCORE
return best
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# def __init__(self, GoldParse gold):
# self.mem = Pool()
# self.system = EntityRecognition(labels)
# self._state = init_state(self.mem, tokens, gold.length)
#
# def transition(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# trans.do(trans, self._state)
#
# def is_valid(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _is_valid(trans.move, trans.label, self._state)
#
# def is_gold(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _get_const(trans, self._state, self._gold)
#
# property ent:
# def __get__(self):
# pass
#
# property n_ents:
# def __get__(self):
# pass
#
# property i:
# def __get__(self):
# pass
#
# property open_entity:
# def __get__(self):
# return entity_is_open(self._s)