* Prepare to switch to using state class, instead of state struct

This commit is contained in:
Matthew Honnibal 2015-06-09 21:20:14 +02:00
parent 2b9629ed62
commit 0895d454fb
10 changed files with 245 additions and 141 deletions

View File

@ -4,6 +4,7 @@ from ._state cimport State
cdef int fill_context(atom_t* context, State* state) except -1 cdef int fill_context(atom_t* context, State* state) except -1
cdef int _new_fill_context(atom_t* context, State* state) except -1
# Context elements # Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order # Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -20,6 +20,11 @@ from ._state cimport has_head, get_left, get_right
from ._state cimport count_left_kids, count_right_kids from ._state cimport count_left_kids, count_right_kids
from .stateclass cimport StateClass
from cymem.cymem cimport Pool
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil: cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
if token is NULL: if token is NULL:
context[0] = 0 context[0] = 0
@ -60,6 +65,53 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[10] = token.ent_iob context[10] = token.ent_iob
context[11] = token.ent_type context[11] = token.ent_type
cdef int _new_fill_context(atom_t* ctxt, State* state) except -1:
# Fill the feature context array `ctxt` from a parser state.
#
# A StateClass is built from the State struct on every call and then queried
# for the tokens and counts needed by each context slot.
#
# Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix...
cdef StateClass st = StateClass(state.sent_len)
st.from_struct(state)
# Token-level slots. Each fill_token call writes one token's attributes
# starting at the named slot. S_ reads from the stack, B_ from the buffer,
# and L_/R_ take a token index plus a 1-based child position.
fill_token(&ctxt[S2w], st.S_(2))
fill_token(&ctxt[S1w], st.S_(1))
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
fill_token(&ctxt[S0lw], st.L_(st.S(0), 1))
fill_token(&ctxt[S0l2w], st.L_(st.S(0), 2))
fill_token(&ctxt[S0w], st.S_(0))
fill_token(&ctxt[S0r2w], st.R_(st.S(0), 2))
fill_token(&ctxt[S0rw], st.R_(st.S(0), 1))
fill_token(&ctxt[N0lw], st.L_(st.B(0), 1))
fill_token(&ctxt[N0l2w], st.L_(st.B(0), 2))
fill_token(&ctxt[N0w], st.B_(0))
fill_token(&ctxt[N1w], st.B_(1))
fill_token(&ctxt[N2w], st.B_(2))
# The one and two tokens immediately before B(0) in the sentence; safe_get
# substitutes the empty token when the index is out of range.
fill_token(&ctxt[P1w], st.safe_get(st.B(0)-1))
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
# TODO: E0w/E1w are still read from the State struct (get_e0/get_e1) rather
# than from the StateClass.
fill_token(&ctxt[E0w], get_e0(state))
fill_token(&ctxt[E1w], get_e1(state))
# Distance feature, capped at 5 by min().
# NOTE(review): presumably S(0) < B(0), which would make the difference
# negative so the cap never applies -- confirm before changing, since the
# value feeds the model's features.
if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min(st.S(0) - st.B(0), 5) # TODO: This is backwards!!
else:
ctxt[dist] = 0
# Child counts (valencies), each capped at 5.
ctxt[N0lv] = min(st.n_L(st.B(0)), 5)
ctxt[S0lv] = min(st.n_L(st.S(0)), 5)
ctxt[S0rv] = min(st.n_R(st.S(0)), 5)
ctxt[S1lv] = min(st.n_L(st.S(1)), 5)
ctxt[S1rv] = min(st.n_R(st.S(1)), 5)
# Head flags: 0 means the stack is too shallow for that item; otherwise the
# slot holds has_head(...) + 1, i.e. 1 (no head) or 2 (has head).
ctxt[S0_has_head] = 0
ctxt[S1_has_head] = 0
ctxt[S2_has_head] = 0
if st.stack_depth() >= 1:
ctxt[S0_has_head] = st.has_head(st.S(0)) + 1
if st.stack_depth() >= 2:
ctxt[S1_has_head] = st.has_head(st.S(1)) + 1
if st.stack_depth() >= 3:
ctxt[S2_has_head] = st.has_head(st.S(2)) + 1
cdef int fill_context(atom_t* context, State* state) except -1: cdef int fill_context(atom_t* context, State* state) except -1:
# Take care to fill every element of context! # Take care to fill every element of context!

View File

@ -115,29 +115,33 @@ cdef bint has_head(const TokenC* t) nogil:
cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil: cdef const TokenC* get_left(const State* s, const TokenC* head, const int idx) nogil:
return _new_get_left(s, head, idx) return _new_get_left(s, head, idx)
#cdef uint32_t kids = head.l_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head - offset
#if child >= s.sent:
# return child
##else:
# return NULL
"""
cdef uint32_t kids = head.l_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head - offset
if child >= s.sent:
return child
else:
return NULL
"""
cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil: cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx) nogil:
return _new_get_right(s, head, idx) return _new_get_right(s, head, idx)
#cdef uint32_t kids = head.r_kids
#if kids == 0:
# return NULL
#cdef int offset = _nth_significant_bit(kids, idx)
#cdef const TokenC* child = head + offset
#if child < (s.sent + s.sent_len):
# return child
#else:
# return NULL
"""
cdef uint32_t kids = head.r_kids
if kids == 0:
return NULL
cdef int offset = _nth_significant_bit(kids, idx)
cdef const TokenC* child = head + offset
if child < (s.sent + s.sent_len):
return child
else:
return NULL
"""
cdef int count_left_kids(const TokenC* head) nogil: cdef int count_left_kids(const TokenC* head) nogil:
return _popcount(head.l_kids) return _popcount(head.l_kids)

View File

@ -6,6 +6,5 @@ from thinc.typedefs cimport weight_t
from ._state cimport State from ._state cimport State
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
pass pass

View File

@ -22,7 +22,7 @@ from libc.stdint cimport uint32_t
from libc.string cimport memcpy from libc.string cimport memcpy
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..stateclass cimport StateClass from .stateclass cimport StateClass
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
@ -59,32 +59,63 @@ MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1: cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
# When we push a word, we can't make arcs to or from the stack. So, we lose cdef StateClass stcls = StateClass(st.sent_len)
# any of those arcs. stcls.from_struct(st)
cdef int cost = 0 cdef int cost = 0
cost += head_in_stack(st, target, gold.heads) cdef int i, S_i
cost += children_in_stack(st, target, gold.heads) for i in range(stcls.stack_depth()):
# If we can Break, we shouldn't push S_i = stcls.S(i)
if gold.heads[target] == S_i:
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0 cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
return cost return cost
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
#cost += head_in_stack(st, target, gold.heads)
#cost += children_in_stack(st, target, gold.heads)
# If we can Break, we shouldn't push
#cost += Break.is_valid(st, -1) and Break.move_cost(st, gold) == 0
#return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1: cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
cdef int cost = 0 cdef int cost = 0
cost += children_in_buffer(st, target, gold.heads) cdef int i, B_i
cost += head_in_buffer(st, target, gold.heads) for i in range(stcls.buffer_length()):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
return cost return cost
#cost += children_in_buffer(st, target, gold.heads)
#cost += head_in_buffer(st, target, gold.heads)
#return cost
cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1: cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) except -1:
cdef StateClass stcls = StateClass(st.sent_len)
stcls.from_struct(st)
if arc_is_gold(gold, head, child): if arc_is_gold(gold, head, child):
return 0 return 0
elif (child + st.sent[child].head) == gold.heads[child]: elif stcls.H(child) == gold.heads[child]:
return 1 return 1
elif gold.heads[child] >= st.i: elif gold.heads[child] >= stcls.B(0):
return 1 return 1
else: else:
return 0 return 0
#if arc_is_gold(gold, head, child):
# return 0
#elif (child + st.sent[child].head) == gold.heads[child]:
# return 1
#elif gold.heads[child] >= st.i:
# return 1
#else:
# return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
@ -122,7 +153,6 @@ cdef class Shift:
cdef bint _new_is_valid(StateClass st, int label) except -1: cdef bint _new_is_valid(StateClass st, int label) except -1:
return not st.eol() return not st.eol()
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(State* state, int label) except -1:
# Set the dep label, in case we need it after we reduce # Set the dep label, in case we need it after we reduce
@ -596,14 +626,17 @@ cdef class ArcEager(TransitionSystem):
state.sent[i].dep = root_label state.sent[i].dep = root_label
cdef int set_valid(self, bint* output, const State* state) except -1: cdef int set_valid(self, bint* output, const State* state) except -1:
raise Exception
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(state, -1) is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(state, -1) is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(state, -1) is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(state, -1) is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(state, -1) is_valid[BREAK] = Break._new_is_valid(stcls, -1)
is_valid[CONSTITUENT] = Constituent.is_valid(state, -1) is_valid[CONSTITUENT] = False # Constituent.is_valid(state, -1)
is_valid[ADJUST] = Adjust.is_valid(state, -1) is_valid[ADJUST] = False # Adjust.is_valid(state, -1)
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
@ -641,10 +674,10 @@ cdef class ArcEager(TransitionSystem):
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Pool mem = Pool() assert s is not NULL
cdef StateClass stcls = StateClass.from_struct(mem, s) cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
#is_valid[SHIFT] = Shift.is_valid(s, -1)
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1) is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1) is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1) is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)

View File

@ -1,4 +1,5 @@
# cython: profile=True # cython: profile=True
# cython: experimental_cpp_class_def=True
""" """
MALT-style dependency parser MALT-style dependency parser
""" """
@ -38,7 +39,9 @@ from ._state cimport State, new_state, copy_state, is_final, push_stack, get_lef
from ..gold cimport GoldParse from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport _new_fill_context as fill_context
#from ._parse_features cimport fill_context
DEBUG = False DEBUG = False

View File

@ -2,14 +2,11 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from structs cimport TokenC from ..structs cimport TokenC
from .syntax._state cimport State from ._state cimport State
from .vocab cimport EMPTY_LEXEME from ..vocab cimport EMPTY_LEXEME
cdef TokenC EMPTY_TOKEN
cdef class StateClass: cdef class StateClass:
@ -17,45 +14,13 @@ cdef class StateClass:
cdef int* _stack cdef int* _stack
cdef int* _buffer cdef int* _buffer
cdef TokenC* _sent cdef TokenC* _sent
cdef TokenC _empty_token
cdef int length cdef int length
cdef int _s_i cdef int _s_i
cdef int _b_i cdef int _b_i
@staticmethod cdef int from_struct(self, const State* state) except -1
cdef inline StateClass init(const TokenC* sent, int length):
cdef StateClass self = StateClass(length)
memcpy(self._sent, sent, sizeof(TokenC*) * length)
return self
@staticmethod
cdef inline StateClass from_struct(Pool mem, const State* state):
cdef StateClass self = StateClass.init(state.sent, state.sent_len)
memcpy(self._stack, state.stack - state.stack_len, sizeof(int) * state.stack_len)
self._s_i = state.stack_len - 1
self._b_i = state.i
return self
cdef inline const TokenC* S_(self, int i) nogil:
return self.safe_get(self.S(i))
cdef inline const TokenC* B_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef inline const TokenC* H_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef inline const TokenC* L_(self, int i, int idx) nogil:
return self.safe_get(self.L(i, idx))
cdef inline const TokenC* R_(self, int i, int idx) nogil:
return self.safe_get(self.R(i, idx))
cdef inline const TokenC* safe_get(self, int i) nogil:
if 0 >= i >= self.length:
return &EMPTY_TOKEN
else:
return self._sent
cdef int S(self, int i) nogil cdef int S(self, int i) nogil
cdef int B(self, int i) nogil cdef int B(self, int i) nogil
@ -64,6 +29,16 @@ cdef class StateClass:
cdef int L(self, int i, int idx) nogil cdef int L(self, int i, int idx) nogil
cdef int R(self, int i, int idx) nogil cdef int R(self, int i, int idx) nogil
cdef const TokenC* S_(self, int i) nogil
cdef const TokenC* B_(self, int i) nogil
cdef const TokenC* H_(self, int i) nogil
cdef const TokenC* L_(self, int i, int idx) nogil
cdef const TokenC* R_(self, int i, int idx) nogil
cdef const TokenC* safe_get(self, int i) nogil
cdef bint empty(self) nogil cdef bint empty(self) nogil
cdef bint eol(self) nogil cdef bint eol(self) nogil
@ -72,6 +47,10 @@ cdef class StateClass:
cdef bint has_head(self, int i) nogil cdef bint has_head(self, int i) nogil
cdef int n_L(self, int i) nogil
cdef int n_R(self, int i) nogil
cdef bint stack_is_connected(self) nogil cdef bint stack_is_connected(self) nogil
cdef int stack_depth(self) nogil cdef int stack_depth(self) nogil

View File

@ -1,24 +1,33 @@
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from .vocab cimport EMPTY_LEXEME from ..vocab cimport EMPTY_LEXEME
memset(&EMPTY_TOKEN, 0, sizeof(TokenC))
EMPTY_TOKEN.lex = &EMPTY_LEXEME
cdef class StateClass: cdef class StateClass:
def __cinit__(self, int length): def __init__(self, int length):
self.mem = Pool() cdef Pool mem = Pool()
self._stack = <int*>self.mem.alloc(sizeof(int), length) self._buffer = <int*>mem.alloc(length, sizeof(int))
self._buffer = <int*>self.mem.alloc(sizeof(int), length) self._stack = <int*>mem.alloc(length, sizeof(int))
self._sent = <TokenC*>self.mem.alloc(sizeof(TokenC*), length) self._sent = <TokenC*>mem.alloc(length, sizeof(TokenC))
self.length = 0 self.mem = mem
for i in range(self.length): self.length = length
self._s_i = 0
self._b_i = 0
cdef int i
for i in range(length):
self._buffer[i] = i self._buffer[i] = i
self._empty_token.lex = &EMPTY_LEXEME
cdef int from_struct(self, const State* state) except -1:
    # Load this object from a State struct: copy the token array, take the
    # buffer position from state.i, and rebuild the stack array.
    cdef int depth = state.stack_len
    cdef int k
    # Copies self.length tokens, so the object must have been constructed
    # with the struct's sentence length.
    memcpy(self._sent, state.sent, self.length * sizeof(TokenC))
    self._b_i = state.i
    self._s_i = depth
    # state.stack is read at offsets 0, -1, -2, ...; the value at offset 0
    # lands in the last used slot, which is the one S(0) reads.
    for k in range(depth):
        self._stack[depth - k - 1] = state.stack[-k]
cdef int S(self, int i) nogil: cdef int S(self, int i) nogil:
if self._s_i - (i+1) < 0: if i >= self._s_i:
return -1 return -1
return self._stack[self._s_i - (i+1)] return self._stack[self._s_i - (i+1)]
@ -33,14 +42,71 @@ cdef class StateClass:
return self._sent[i].head + i return self._sent[i].head + i
cdef int L(self, int i, int idx) nogil: cdef int L(self, int i, int idx) nogil:
# Return the index of the idx-th (1-based) token, scanning left-to-right from
# the start of the sentence, whose head offset points at token i.
# Returns -1 when idx < 1, when i is out of range, or when no such token is
# found before reaching token i.
if 0 <= _popcount(self.safe_get(i).l_kids) <= idx: if idx < 1:
return -1 return -1
return _nth_significant_bit(self.safe_get(i).l_kids, idx) if i < 0 or i >= self.length:
return -1
cdef const TokenC* target = &self._sent[i]
cdef const TokenC* ptr = self._sent
while ptr < target:
# If ptr's head lies to its right but still left of the target, jump
# straight to that head: no token between ptr and its head could be a
# child of the target.
if (ptr.head >= 1) and (ptr + ptr.head) < target:
ptr += ptr.head
elif ptr + ptr.head == target:
# ptr is attached to the target: count it, and stop at the idx-th match.
idx -= 1
if idx == 0:
# Pointer difference gives the token's index in the sentence.
return ptr - self._sent
ptr += 1
else:
ptr += 1
return -1
cdef int R(self, int i, int idx) nogil: cdef int R(self, int i, int idx) nogil:
# Return the index of the idx-th (1-based) token, scanning right-to-left from
# the end of the sentence, whose head offset points at token i.
# Returns -1 when idx < 1, when i is out of range, or when no such token is
# found before reaching token i.
if 0 <= _popcount(self.safe_get(i).r_kids) <= idx: if idx < 1:
return -1 return -1
return _nth_significant_bit(self.safe_get(i).r_kids, idx) if i < 0 or i >= self.length:
return -1
cdef const TokenC* ptr = self._sent + (self.length - 1)
cdef const TokenC* target = &self._sent[i]
while ptr > target:
# If ptr's head lies to its left but still right of the target, jump
# straight to that head: no token between ptr and its head could be a
# child of the target.
if (ptr.head < 0) and ((ptr + ptr.head) > target):
ptr += ptr.head
elif ptr + ptr.head == target:
# ptr is attached to the target: count it, and stop at the idx-th match.
idx -= 1
if idx == 0:
# Pointer difference gives the token's index in the sentence.
return ptr - self._sent
ptr -= 1
else:
ptr -= 1
return -1
cdef const TokenC* S_(self, int i) nogil:
    # Token struct for the i-th stack item; safe_get supplies the empty
    # token when S(i) yields an out-of-range index.
    cdef int word = self.S(i)
    return self.safe_get(word)
cdef const TokenC* B_(self, int i) nogil:
    # Token struct for the index returned by B(i); safe_get supplies the
    # empty token when that index is out of range.
    cdef int word = self.B(i)
    return self.safe_get(word)
cdef const TokenC* H_(self, int i) nogil:
    # Token struct for the head of token i.
    #
    # The previous body called self.B(i), making this a duplicate of B_ and
    # returning a buffer token instead of the head. Use H(i) so the pointer
    # accessor matches its integer counterpart, as S_/B_/L_/R_ do.
    # NOTE(review): relies on H() coping with an out-of-range i -- confirm.
    # safe_get maps an out-of-range result to the empty token.
    return self.safe_get(self.H(i))
cdef const TokenC* L_(self, int i, int idx) nogil:
    # Token struct for the index returned by L(i, idx); the -1 that L()
    # returns on failure is mapped to the empty token by safe_get.
    cdef int child = self.L(i, idx)
    return self.safe_get(child)
cdef const TokenC* R_(self, int i, int idx) nogil:
    # Token struct for the index returned by R(i, idx); the -1 that R()
    # returns on failure is mapped to the empty token by safe_get.
    cdef int child = self.R(i, idx)
    return self.safe_get(child)
cdef const TokenC* safe_get(self, int i) nogil:
    # Bounds-checked token lookup: a valid index returns a pointer into the
    # sentence array, anything else returns this object's empty token.
    if 0 <= i < self.length:
        return &self._sent[i]
    return &self._empty_token
cdef bint empty(self) nogil: cdef bint empty(self) nogil:
return self._s_i <= 0 return self._s_i <= 0
@ -54,6 +120,12 @@ cdef class StateClass:
cdef bint has_head(self, int i) nogil: cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0 return self.safe_get(i).head != 0
cdef int n_L(self, int i) nogil:
    # Population count of token i's l_kids bit field (the empty token is
    # used when i is out of range).
    cdef const TokenC* token = self.safe_get(i)
    return _popcount(token.l_kids)
cdef int n_R(self, int i) nogil:
    # Population count of token i's r_kids bit field (the empty token is
    # used when i is out of range).
    cdef const TokenC* token = self.safe_get(i)
    return _popcount(token.r_kids)
cdef bint stack_is_connected(self) nogil: cdef bint stack_is_connected(self) nogil:
return False return False

View File

@ -51,10 +51,3 @@ cdef class TransitionSystem:
cdef Transition best_gold(self, const weight_t* scores, const State* state, cdef Transition best_gold(self, const weight_t* scores, const State* state,
GoldParse gold) except * GoldParse gold) except *
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# cdef Pool mem
# cdef TransitionSystem system
# cdef State* _state

View File

@ -3,6 +3,8 @@ from ._state cimport State
from ..structs cimport TokenC from ..structs cimport TokenC
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from .stateclass cimport StateClass
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
@ -55,6 +57,8 @@ cdef class TransitionSystem:
cdef Transition best_gold(self, const weight_t* scores, const State* s, cdef Transition best_gold(self, const weight_t* scores, const State* s,
GoldParse gold) except *: GoldParse gold) except *:
cdef StateClass stcls = StateClass(s.sent_len)
stcls.from_struct(s)
cdef Transition best cdef Transition best
cdef weight_t score = MIN_SCORE cdef weight_t score = MIN_SCORE
cdef int i cdef int i
@ -65,39 +69,3 @@ cdef class TransitionSystem:
score = scores[i] score = scores[i]
assert score > MIN_SCORE assert score > MIN_SCORE
return best return best
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# def __init__(self, GoldParse gold):
# self.mem = Pool()
# self.system = EntityRecognition(labels)
# self._state = init_state(self.mem, tokens, gold.length)
#
# def transition(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# trans.do(trans, self._state)
#
# def is_valid(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _is_valid(trans.move, trans.label, self._state)
#
# def is_gold(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _get_const(trans, self._state, self._gold)
#
# property ent:
# def __get__(self):
# pass
#
# property n_ents:
# def __get__(self):
# pass
#
# property i:
# def __get__(self):
# pass
#
# property open_entity:
# def __get__(self):
# return entity_is_open(self._s)