mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-25 03:13:41 +03:00
Use ordered dict to specify actions
This commit is contained in:
parent
655ca58c16
commit
99316fa631
|
@ -9,6 +9,7 @@ import ctypes
|
||||||
from libc.stdint cimport uint32_t
|
from libc.stdint cimport uint32_t
|
||||||
from libc.string cimport memcpy
|
from libc.string cimport memcpy
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC, is_space_token
|
from ._state cimport StateC, is_space_token
|
||||||
|
@ -312,12 +313,13 @@ cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
actions = kwargs.get('actions',
|
actions = kwargs.get('actions',
|
||||||
{
|
OrderedDict((
|
||||||
SHIFT: [''],
|
(SHIFT, ['']),
|
||||||
REDUCE: [''],
|
(REDUCE, ['']),
|
||||||
RIGHT: [],
|
(RIGHT, []),
|
||||||
LEFT: [],
|
(LEFT, []),
|
||||||
BREAK: ['ROOT']})
|
(BREAK, ['ROOT'])
|
||||||
|
)))
|
||||||
seen_actions = set()
|
seen_actions = set()
|
||||||
for label in kwargs.get('left_labels', []):
|
for label in kwargs.get('left_labels', []):
|
||||||
if label.upper() != 'ROOT':
|
if label.upper() != 'ROOT':
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from thinc.typedefs cimport weight_t
|
from thinc.typedefs cimport weight_t
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
|
@ -51,17 +52,29 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
|
||||||
|
|
||||||
|
|
||||||
cdef class BiluoPushDown(TransitionSystem):
|
cdef class BiluoPushDown(TransitionSystem):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
TransitionSystem.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
labels_by_action = OrderedDict()
|
||||||
|
cdef Transition t
|
||||||
|
for trans in self.c[:self.n_moves]:
|
||||||
|
label_str = self.strings[trans.label]
|
||||||
|
labels_by_action.setdefault(trans.move, []).append(label_str)
|
||||||
|
return (BiluoPushDown, (self.strings, labels_by_action),
|
||||||
|
None, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
actions = kwargs.get('actions',
|
actions = kwargs.get('actions',
|
||||||
{
|
OrderedDict((
|
||||||
MISSING: [''],
|
(MISSING, ['']),
|
||||||
BEGIN: [],
|
(BEGIN, []),
|
||||||
IN: [],
|
(IN, []),
|
||||||
LAST: [],
|
(LAST, []),
|
||||||
UNIT: [],
|
(UNIT, []),
|
||||||
OUT: ['']
|
(OUT, [''])
|
||||||
})
|
)))
|
||||||
seen_entities = set()
|
seen_entities = set()
|
||||||
for entity_type in kwargs.get('entity_types', []):
|
for entity_type in kwargs.get('entity_types', []):
|
||||||
if entity_type in seen_entities:
|
if entity_type in seen_entities:
|
||||||
|
@ -90,7 +103,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
def move_name(self, int move, int label):
|
def move_name(self, int move, int label):
|
||||||
if move == OUT:
|
if move == OUT:
|
||||||
return 'O'
|
return 'O'
|
||||||
elif move == 'MISSING':
|
elif move == MISSING:
|
||||||
return 'M'
|
return 'M'
|
||||||
else:
|
else:
|
||||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user