* Begin adding stateclass to ArcEager

This commit is contained in:
Matthew Honnibal 2015-06-09 01:41:09 +02:00
parent ba10fd8af5
commit 2b9629ed62

View File

@ -1,6 +1,9 @@
# cython: profile=True
from __future__ import unicode_literals
import ctypes
import os
from ._state cimport State
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
@ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse
from ..gold cimport GoldParseC
from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from cymem.cymem cimport Pool
from ..stateclass cimport StateClass
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
@ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
if gold.labels[child] == -1:
return True
@ -110,6 +118,11 @@ cdef class Shift:
cdef bint is_valid(const State* s, int label) except -1:
return not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
return not st.eol()
@staticmethod
cdef int transition(State* state, int label) except -1:
# Set the dep label, in case we need it after we reduce
@ -133,6 +146,13 @@ cdef class Shift:
cdef class Reduce:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC:
return st.stack_depth() >= 2 #and not missing_brackets(s)
else:
return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
@ -165,6 +185,13 @@ cdef class Reduce:
cdef class LeftArc:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
if NON_MONOTONIC:
return st.stack_depth() >= 1 #and not missing_brackets(s)
else:
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
if NON_MONOTONIC:
@ -206,6 +233,10 @@ cdef class RightArc:
cdef bint is_valid(const State* s, int label) except -1:
return s.stack_len >= 1 and not at_eol(s)
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
return st.stack_depth() >= 1 and not st.eol()
@staticmethod
cdef int transition(State* state, int label) except -1:
add_dep(state, state.stack[0], state.i, label)
@ -230,6 +261,32 @@ cdef class RightArc:
cdef class Break:
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
cdef int i
if not USE_BREAK:
return False
elif st.eol():
return False
elif st.stack_depth() < 1:
return False
elif NON_MONOTONIC:
return True
else:
# In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
# we prefer to relax this constraint. This is helpful in parsing whole
# documents, because then we don't get stuck with words on the stack.
seen_headless = False
for i in range(st.stack_depth()):
if not st.has_head(st.S(i)):
if seen_headless:
return False
else:
seen_headless = True
# TODO: Constituency constraints
return True
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
cdef int i
@ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem):
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef Pool mem = Pool()
cdef StateClass stcls = StateClass.from_struct(mem, s)
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(s, -1)
is_valid[REDUCE] = Reduce.is_valid(s, -1)
is_valid[LEFT] = LeftArc.is_valid(s, -1)
is_valid[RIGHT] = RightArc.is_valid(s, -1)
is_valid[BREAK] = Break.is_valid(s, -1)
is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
is_valid[ADJUST] = Adjust.is_valid(s, -1)
#is_valid[SHIFT] = Shift.is_valid(s, -1)
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
cdef Transition best
cdef weight_t score = MIN_SCORE
cdef int i