WIP start excising GoldParse

This commit is contained in:
Matthew Honnibal 2020-06-14 17:11:41 +02:00
parent a1c5b694be
commit 7d65615625
2 changed files with 10 additions and 12 deletions

View File

@ -2,8 +2,6 @@ from cymem.cymem cimport Pool
from ..typedefs cimport attr_t, weight_t from ..typedefs cimport attr_t, weight_t
from ..structs cimport TokenC from ..structs cimport TokenC
from ..gold cimport GoldParse
from ..gold cimport GoldParseC
from ..strings cimport StringStore from ..strings cimport StringStore
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
@ -17,14 +15,14 @@ cdef struct Transition:
weight_t score weight_t score
bint (*is_valid)(const StateC* state, attr_t label) nogil bint (*is_valid)(const StateC* state, attr_t label) nogil
weight_t (*get_cost)(StateClass state, const GoldParseC* gold, attr_t label) nogil weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil
int (*do)(StateC* state, attr_t label) nogil int (*do)(StateC* state, attr_t label) nogil
ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold,
attr_tlabel) nogil attr_tlabel) nogil
ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil
ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* ctypedef weight_t (*label_cost_func_t)(StateClass state, const void*
gold, attr_t label) nogil gold, attr_t label) nogil
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
@ -55,4 +53,4 @@ cdef class TransitionSystem:
cdef int set_valid(self, int* output, const StateC* st) nogil cdef int set_valid(self, int* output, const StateC* st) nogil
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, GoldParse gold) except -1 StateClass state, NewExample example) except -1

View File

@ -87,7 +87,7 @@ cdef class TransitionSystem:
beams.append(beam) beams.append(beam)
return beams return beams
def get_oracle_sequence(self, doc, GoldParse gold): def get_oracle_sequence(self, NewExample example):
cdef Pool mem = Pool() cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0 assert self.n_moves > 0
@ -98,7 +98,7 @@ cdef class TransitionSystem:
self.initialize_state(state.c) self.initialize_state(state.c)
history = [] history = []
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state, gold) self.set_costs(is_valid, costs, state, example)
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]
@ -124,10 +124,10 @@ cdef class TransitionSystem:
def finalize_doc(self, doc): def finalize_doc(self, doc):
pass pass
def preprocess_gold(self, GoldParse gold): def preprocess_gold(self, gold):
raise NotImplementedError raise NotImplementedError
def is_gold_parse(self, StateClass state, GoldParse gold): def is_gold_parse(self, StateClass state, example):
raise NotImplementedError raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
@ -148,7 +148,7 @@ cdef class TransitionSystem:
is_valid[i] = self.c[i].is_valid(st, self.c[i].label) is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1: StateClass stcls, NewExample example) except -1:
cdef int i cdef int i
self.set_valid(is_valid, stcls.c) self.set_valid(is_valid, stcls.c)
cdef int n_gold = 0 cdef int n_gold = 0