mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Begin adding stateclass to ArcEager
This commit is contained in:
parent
ba10fd8af5
commit
2b9629ed62
|
@ -1,6 +1,9 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
from ._state cimport State
|
||||
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
|
||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
||||
|
@ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
|
|||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
from libc.stdint cimport uint32_t
|
||||
from libc.string cimport memcpy
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
from ..stateclass cimport StateClass
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
@ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
|
|||
return 0
|
||||
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
||||
if gold.labels[child] == -1:
|
||||
return True
|
||||
|
@ -110,6 +118,11 @@ cdef class Shift:
|
|||
cdef bint is_valid(const State* s, int label) except -1:
    # Struct-based validity check for Shift: the move is rejected exactly
    # when at_eol(s) is true. `label` is not consulted.
    if at_eol(s):
        return False
    return True
|
||||
|
||||
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
    # StateClass-based validity check for Shift: the move is rejected
    # exactly when st.eol() is true. `label` is not consulted.
    if st.eol():
        return False
    return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
|
@ -133,6 +146,13 @@ cdef class Shift:
|
|||
|
||||
|
||||
cdef class Reduce:
|
||||
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
    # StateClass-based validity check for Reduce. `label` is not consulted.
    #
    # Both modes require a stack depth of at least two.
    if st.stack_depth() < 2:
        return False
    if NON_MONOTONIC:
        # No further condition in non-monotonic mode. (A check on
        # missing_brackets(s) was left disabled here.)
        return True
    # Monotonic mode: st.S(0) must additionally already have a head.
    return st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
if NON_MONOTONIC:
|
||||
|
@ -165,6 +185,13 @@ cdef class Reduce:
|
|||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
    # StateClass-based validity check for LeftArc. `label` is not consulted.
    #
    # Both modes require a stack depth of at least one.
    if st.stack_depth() < 1:
        return False
    if NON_MONOTONIC:
        # No further condition in non-monotonic mode. (A check on
        # missing_brackets(s) was left disabled here.)
        return True
    # Monotonic mode: st.S(0) must not have a head yet.
    return not st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
if NON_MONOTONIC:
|
||||
|
@ -206,6 +233,10 @@ cdef class RightArc:
|
|||
cdef bint is_valid(const State* s, int label) except -1:
    # Struct-based validity check for RightArc: needs s.stack_len >= 1 and
    # at_eol(s) to be false. `label` is not consulted.
    if s.stack_len < 1:
        return False
    return not at_eol(s)
|
||||
|
||||
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
    # StateClass-based validity check for RightArc: needs a stack depth of
    # at least one and st.eol() to be false. `label` is not consulted.
    if st.stack_depth() < 1:
        return False
    return not st.eol()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
add_dep(state, state.stack[0], state.i, label)
|
||||
|
@ -230,6 +261,32 @@ cdef class RightArc:
|
|||
|
||||
|
||||
cdef class Break:
|
||||
@staticmethod
cdef bint _new_is_valid(StateClass st, int label) except -1:
    # StateClass-based validity check for Break. `label` is not consulted.
    cdef int depth_idx
    cdef int n_headless = 0
    # Break is never valid when the move is disabled, when st.eol() is
    # true, or when the stack is empty. The checks run in this order.
    if not USE_BREAK or st.eol() or st.stack_depth() < 1:
        return False
    if NON_MONOTONIC:
        # In the Break transition paper, they have a constraint that prevents
        # Break if the stack is disconnected. When doing non-monotonic
        # parsing we prefer to relax it. This is helpful in parsing whole
        # documents, because then we don't get stuck with words on the stack.
        return True
    # Monotonic mode: apply the constraint by refusing Break as soon as a
    # second stack entry without a head is found.
    for depth_idx in range(st.stack_depth()):
        if not st.has_head(st.S(depth_idx)):
            n_headless += 1
            if n_headless >= 2:
                return False
    # TODO: Constituency constraints
    return True
|
||||
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
cdef int i
|
||||
|
@ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem):
|
|||
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef Pool mem = Pool()
|
||||
cdef StateClass stcls = StateClass.from_struct(mem, s)
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(s, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(s, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(s, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(s, -1)
|
||||
is_valid[BREAK] = Break.is_valid(s, -1)
|
||||
is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
|
||||
is_valid[ADJUST] = Adjust.is_valid(s, -1)
|
||||
#is_valid[SHIFT] = Shift.is_valid(s, -1)
|
||||
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
|
||||
is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
|
||||
is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
|
|
Loading…
Reference in New Issue
Block a user