Use ordered dict to specify actions

This commit is contained in:
Matthew Honnibal 2017-05-27 15:50:21 -05:00
parent 655ca58c16
commit 99316fa631
2 changed files with 30 additions and 15 deletions

View File

@ -9,6 +9,7 @@ import ctypes
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from libc.string cimport memcpy from libc.string cimport memcpy
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from collections import OrderedDict
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC, is_space_token from ._state cimport StateC, is_space_token
@ -312,12 +313,13 @@ cdef class ArcEager(TransitionSystem):
@classmethod @classmethod
def get_actions(cls, **kwargs): def get_actions(cls, **kwargs):
actions = kwargs.get('actions', actions = kwargs.get('actions',
{ OrderedDict((
SHIFT: [''], (SHIFT, ['']),
REDUCE: [''], (REDUCE, ['']),
RIGHT: [], (RIGHT, []),
LEFT: [], (LEFT, []),
BREAK: ['ROOT']}) (BREAK, ['ROOT'])
)))
seen_actions = set() seen_actions = set()
for label in kwargs.get('left_labels', []): for label in kwargs.get('left_labels', []):
if label.upper() != 'ROOT': if label.upper() != 'ROOT':

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from collections import OrderedDict
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
@ -51,17 +52,29 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
cdef class BiluoPushDown(TransitionSystem): cdef class BiluoPushDown(TransitionSystem):
def __init__(self, *args, **kwargs):
TransitionSystem.__init__(self, *args, **kwargs)
def __reduce__(self):
labels_by_action = OrderedDict()
cdef Transition t
for trans in self.c[:self.n_moves]:
label_str = self.strings[trans.label]
labels_by_action.setdefault(trans.move, []).append(label_str)
return (BiluoPushDown, (self.strings, labels_by_action),
None, None)
@classmethod @classmethod
def get_actions(cls, **kwargs): def get_actions(cls, **kwargs):
actions = kwargs.get('actions', actions = kwargs.get('actions',
{ OrderedDict((
MISSING: [''], (MISSING, ['']),
BEGIN: [], (BEGIN, []),
IN: [], (IN, []),
LAST: [], (LAST, []),
UNIT: [], (UNIT, []),
OUT: [''] (OUT, [''])
}) )))
seen_entities = set() seen_entities = set()
for entity_type in kwargs.get('entity_types', []): for entity_type in kwargs.get('entity_types', []):
if entity_type in seen_entities: if entity_type in seen_entities:
@ -90,7 +103,7 @@ cdef class BiluoPushDown(TransitionSystem):
def move_name(self, int move, int label): def move_name(self, int move, int label):
if move == OUT: if move == OUT:
return 'O' return 'O'
elif move == 'MISSING': elif move == MISSING:
return 'M' return 'M'
else: else:
return MOVE_NAMES[move] + '-' + self.strings[label] return MOVE_NAMES[move] + '-' + self.strings[label]