Use ordered dict to specify transitions

This commit is contained in:
Matthew Honnibal 2017-05-27 15:52:20 -05:00
parent 3eea5383a1
commit 7ebd26b8aa

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from collections import defaultdict from collections import defaultdict, OrderedDict
from ..structs cimport TokenC from ..structs cimport TokenC
from .stateclass cimport StateClass from .stateclass cimport StateClass
@ -26,7 +26,7 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef class TransitionSystem: cdef class TransitionSystem:
def __init__(self, StringStore string_table, dict labels_by_action): def __init__(self, StringStore string_table, labels_by_action):
self.mem = Pool() self.mem = Pool()
self.strings = string_table self.strings = string_table
self.n_moves = 0 self.n_moves = 0
@ -34,14 +34,14 @@ cdef class TransitionSystem:
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition)) self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
for action, label_strs in sorted(labels_by_action.items()): for action, label_strs in labels_by_action.items():
for label_str in label_strs: for label_str in label_strs:
self.add_action(int(action), label_str) self.add_action(int(action), label_str)
self.root_label = self.strings['ROOT'] self.root_label = self.strings['ROOT']
self.init_beam_state = _init_state self.init_beam_state = _init_state
def __reduce__(self): def __reduce__(self):
labels_by_action = {} labels_by_action = OrderedDict()
cdef Transition t cdef Transition t
for trans in self.c[:self.n_moves]: for trans in self.c[:self.n_moves]:
label_str = self.strings[trans.label] label_str = self.strings[trans.label]
@ -77,6 +77,11 @@ cdef class TransitionSystem:
history.append(i) history.append(i)
action.do(state.c, action.label) action.do(state.c, action.label)
break break
else:
print(gold.words)
print(gold.ner)
print(history)
raise ValueError("Could not find gold move")
return history return history
cdef int initialize_state(self, StateC* state) nogil: cdef int initialize_state(self, StateC* state) nogil: