mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
* Begin adding stateclass to ArcEager
This commit is contained in:
parent
ba10fd8af5
commit
2b9629ed62
|
@ -1,6 +1,9 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
import os
|
||||||
|
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
|
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
|
||||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
||||||
|
@ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
from ..gold cimport GoldParseC
|
from ..gold cimport GoldParseC
|
||||||
|
|
||||||
|
from libc.stdint cimport uint32_t
|
||||||
|
from libc.string cimport memcpy
|
||||||
|
|
||||||
|
from cymem.cymem cimport Pool
|
||||||
|
from ..stateclass cimport StateClass
|
||||||
|
|
||||||
|
|
||||||
DEF NON_MONOTONIC = True
|
DEF NON_MONOTONIC = True
|
||||||
DEF USE_BREAK = True
|
DEF USE_BREAK = True
|
||||||
|
@ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
|
||||||
if gold.labels[child] == -1:
|
if gold.labels[child] == -1:
|
||||||
return True
|
return True
|
||||||
|
@ -110,6 +118,11 @@ cdef class Shift:
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
return not at_eol(s)
|
return not at_eol(s)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef bint _new_is_valid(StateClass st, int label) except -1:
|
||||||
|
return not st.eol()
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(State* state, int label) except -1:
|
||||||
# Set the dep label, in case we need it after we reduce
|
# Set the dep label, in case we need it after we reduce
|
||||||
|
@ -133,6 +146,13 @@ cdef class Shift:
|
||||||
|
|
||||||
|
|
||||||
cdef class Reduce:
|
cdef class Reduce:
|
||||||
|
@staticmethod
|
||||||
|
cdef bint _new_is_valid(StateClass st, int label) except -1:
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
return st.stack_depth() >= 2 #and not missing_brackets(s)
|
||||||
|
else:
|
||||||
|
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
|
@ -165,6 +185,13 @@ cdef class Reduce:
|
||||||
|
|
||||||
|
|
||||||
cdef class LeftArc:
|
cdef class LeftArc:
|
||||||
|
@staticmethod
|
||||||
|
cdef bint _new_is_valid(StateClass st, int label) except -1:
|
||||||
|
if NON_MONOTONIC:
|
||||||
|
return st.stack_depth() >= 1 #and not missing_brackets(s)
|
||||||
|
else:
|
||||||
|
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
|
@ -206,6 +233,10 @@ cdef class RightArc:
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
return s.stack_len >= 1 and not at_eol(s)
|
return s.stack_len >= 1 and not at_eol(s)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef bint _new_is_valid(StateClass st, int label) except -1:
|
||||||
|
return st.stack_depth() >= 1 and not st.eol()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(State* state, int label) except -1:
|
cdef int transition(State* state, int label) except -1:
|
||||||
add_dep(state, state.stack[0], state.i, label)
|
add_dep(state, state.stack[0], state.i, label)
|
||||||
|
@ -230,6 +261,32 @@ cdef class RightArc:
|
||||||
|
|
||||||
|
|
||||||
cdef class Break:
|
cdef class Break:
|
||||||
|
@staticmethod
|
||||||
|
cdef bint _new_is_valid(StateClass st, int label) except -1:
|
||||||
|
cdef int i
|
||||||
|
if not USE_BREAK:
|
||||||
|
return False
|
||||||
|
elif st.eol():
|
||||||
|
return False
|
||||||
|
elif st.stack_depth() < 1:
|
||||||
|
return False
|
||||||
|
elif NON_MONOTONIC:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# In the Break transition paper, they have this constraint that prevents
|
||||||
|
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
|
||||||
|
# we prefer to relax this constraint. This is helpful in parsing whole
|
||||||
|
# documents, because then we don't get stuck with words on the stack.
|
||||||
|
seen_headless = False
|
||||||
|
for i in range(st.stack_depth()):
|
||||||
|
if not st.has_head(st.S(i)):
|
||||||
|
if seen_headless:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
seen_headless = True
|
||||||
|
# TODO: Constituency constraints
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const State* s, int label) except -1:
|
cdef bint is_valid(const State* s, int label) except -1:
|
||||||
cdef int i
|
cdef int i
|
||||||
|
@ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem):
|
||||||
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
|
cdef Pool mem = Pool()
|
||||||
|
cdef StateClass stcls = StateClass.from_struct(mem, s)
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = Shift.is_valid(s, -1)
|
#is_valid[SHIFT] = Shift.is_valid(s, -1)
|
||||||
is_valid[REDUCE] = Reduce.is_valid(s, -1)
|
is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
|
||||||
is_valid[LEFT] = LeftArc.is_valid(s, -1)
|
is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
|
||||||
is_valid[RIGHT] = RightArc.is_valid(s, -1)
|
is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
|
||||||
is_valid[BREAK] = Break.is_valid(s, -1)
|
is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
|
||||||
is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
|
is_valid[BREAK] = Break._new_is_valid(stcls, -1)
|
||||||
is_valid[ADJUST] = Adjust.is_valid(s, -1)
|
is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
|
||||||
|
is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
|
||||||
cdef Transition best
|
cdef Transition best
|
||||||
cdef weight_t score = MIN_SCORE
|
cdef weight_t score = MIN_SCORE
|
||||||
cdef int i
|
cdef int i
|
||||||
|
|
Loading…
Reference in New Issue
Block a user