mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Allow users to add_label, in order to extend the entity recogniser to new classes. Does not by itself add a class to the model
This commit is contained in:
parent
c8e0011ebc
commit
151aa0b0e2
|
@ -31,14 +31,12 @@ ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
|||
cdef class TransitionSystem:
|
||||
cdef Pool mem
|
||||
cdef StringStore strings
|
||||
cdef const Transition* c
|
||||
cdef bint* _is_valid
|
||||
cdef Transition* c
|
||||
cdef readonly int n_moves
|
||||
cdef int _size
|
||||
cdef public int root_label
|
||||
cdef public freqs
|
||||
|
||||
cdef object _labels_by_action
|
||||
|
||||
cdef int initialize_state(self, StateClass state) except -1
|
||||
cdef int finalize_state(self, StateClass state) nogil
|
||||
|
||||
|
|
|
@ -16,20 +16,17 @@ class OracleError(Exception):
|
|||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
|
||||
self._labels_by_action = labels_by_action
|
||||
self.mem = Pool()
|
||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
||||
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
cdef int label_id
|
||||
self.strings = string_table
|
||||
self.n_moves = 0
|
||||
self._size = 100
|
||||
|
||||
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
|
||||
|
||||
for action, label_strs in sorted(labels_by_action.items()):
|
||||
for label_str in sorted(label_strs):
|
||||
label_id = self.strings[unicode(label_str)] if label_str else 0
|
||||
moves[i] = self.init_transition(i, int(action), label_id)
|
||||
i += 1
|
||||
self.c = moves
|
||||
self.add_action(int(action), label_str)
|
||||
|
||||
self.root_label = self.strings['ROOT']
|
||||
self.freqs = {} if _freqs is None else _freqs
|
||||
for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
|
||||
|
@ -41,8 +38,13 @@ cdef class TransitionSystem:
|
|||
self.freqs[HEAD][-i] = 1
|
||||
|
||||
def __reduce__(self):
|
||||
labels_by_action = {}
|
||||
cdef Transition t
|
||||
for trans in self.c[:self.n_moves]:
|
||||
label_str = self.strings[trans.label]
|
||||
labels_by_action.setdefault(trans.move, []).append(label_str)
|
||||
return (self.__class__,
|
||||
(self.strings, self._labels_by_action, self.freqs),
|
||||
(self.strings, labels_by_action, self.freqs),
|
||||
None, None)
|
||||
|
||||
cdef int initialize_state(self, StateClass state) except -1:
|
||||
|
@ -78,3 +80,14 @@ cdef class TransitionSystem:
|
|||
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
else:
|
||||
costs[i] = 9000
|
||||
|
||||
def add_action(self, int action, label):
|
||||
if self.n_moves >= self._size:
|
||||
self._size *= 2
|
||||
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
|
||||
|
||||
if not isinstance(label, int):
|
||||
label = self.strings[label]
|
||||
|
||||
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label)
|
||||
self.n_moves += 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user