Constrain subtok label to adjacent tokens

This commit is contained in:
Matthew Honnibal 2018-04-08 15:20:16 +02:00
parent 8f21953fc5
commit 6d0fe67b72

View File

@@ -11,6 +11,8 @@ from thinc.extra.search cimport Beam
import json import json
from .nonproj import is_nonproj_tree from .nonproj import is_nonproj_tree
from ..typedefs cimport hash_t, attr_t
from ..strings cimport hash_string
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
from . import nonproj from . import nonproj
@@ -21,12 +23,12 @@ from ..errors import Errors
# Calculate cost as gold/not gold. We don't use scalar value anyway. # Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1 cdef int BINARY_COSTS = 1
cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok')
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = True DEF USE_BREAK = True
cdef weight_t MIN_SCORE = -90000
# Break transition from here # Break transition from here
# http://www.aclweb.org/anthology/P13-1074 # http://www.aclweb.org/anthology/P13-1074
cdef enum: cdef enum:
@@ -180,7 +182,7 @@ cdef class Reduce:
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
if label == SUBTOK_LABEL and st.S(0).i != (st.B(0).i-1): if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0 return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 return sent_start != 1
@@ -218,7 +220,7 @@ cdef class RightArc:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle. # If there's (perhaps partial) parse pre-set, don't allow cycle.
if label == SUBTOK_LABEL and st.S(0).i != (st.B(0).i-1): if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0 return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 and st.H(st.S(0)) != st.B(0) return sent_start != 1 and st.H(st.S(0)) != st.B(0)
@@ -540,7 +542,10 @@ cdef class ArcEager(TransitionSystem):
is_valid[BREAK] = Break.is_valid(st, 0) is_valid[BREAK] = Break.is_valid(st, 0)
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] if self.c[i].label == SUBTOK_LABEL:
output[i] = self.c[i].is_valid(st, self.c[i].label)
else:
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1: StateClass stcls, GoldParse gold) except -1: