from __future__ import unicode_literals

from ._state cimport begin_entity
from ._state cimport end_entity
from ._state cimport entity_is_open
from ._state cimport entity_is_sunk


# Single-character name of each action, indexed by the action's enum value.
ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
ACTION_NAMES[MISSING] = '?'
ACTION_NAMES[BEGIN] = 'B'
ACTION_NAMES[IN] = 'I'
ACTION_NAMES[LAST] = 'L'
ACTION_NAMES[UNIT] = 'U'
ACTION_NAMES[OUT] = 'O'


cdef bint can_begin(State* s, int label):
    """BEGIN is valid only when no entity is currently open."""
    return not entity_is_open(s)


cdef bint can_in(State* s, int label):
    """IN is valid only inside an open entity carrying the same label."""
    return entity_is_open(s) and s.curr.label == label


cdef bint can_last(State* s, int label):
    """LAST is valid only inside an open entity carrying the same label."""
    return entity_is_open(s) and s.curr.label == label


cdef bint can_unit(State* s, int label):
    """UNIT is valid only when no entity is currently open."""
    return not entity_is_open(s)


cdef bint can_out(State* s, int label):
    """OUT is valid only when no entity is currently open."""
    return not entity_is_open(s)


cdef bint is_oracle(ActionType act, int tag, ActionType g_act, int g_tag,
                    ActionType next_act, bint is_sunk):
    """Decide whether taking (act, tag) is acceptable given the gold move.

    act, tag:      the candidate action and its label.
    g_act, g_tag:  the gold action and label for the current token.
    next_act:      the gold action of the following token.
    is_sunk:       value of entity_is_sunk() for the current state.

    Returns True when the candidate is accepted, False otherwise. A gold
    action of MISSING accepts every candidate.
    """
    if g_act == MISSING:
        return True
    if act == BEGIN:
        if g_act == BEGIN:
            # B, Gold B --> Label match
            return tag == g_tag
        else:
            # B, Gold I --> False (P)
            # B, Gold L --> False (P)
            # B, Gold O --> False (P)
            # B, Gold U --> False (P)
            return False
    elif act == IN:
        if g_act == BEGIN:
            # I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
            return True
        elif g_act == IN:
            # I, Gold I --> True (label forced by prev, if mismatch, P and R both sunk)
            return True
        elif g_act == LAST:
            # I, Gold L --> True iff this entity sunk and next tag == O
            return is_sunk and (next_act == OUT or next_act == MISSING)
        elif g_act == OUT:
            # I, Gold O --> True iff next tag == O
            return next_act == OUT or next_act == MISSING
        elif g_act == UNIT:
            # I, Gold U --> True iff next tag == O
            # NOTE(review): unlike the two cases above, MISSING is not
            # accepted here -- confirm whether that asymmetry is intended.
            return next_act == OUT
    elif act == LAST:
        if g_act == BEGIN:
            # L, Gold B --> True
            return True
        elif g_act == IN:
            # L, Gold I --> True iff this entity sunk
            return is_sunk
        elif g_act == LAST:
            # L, Gold L --> True
            return True
        elif g_act == OUT:
            # L, Gold O --> True
            return True
        elif g_act == UNIT:
            # L, Gold U --> True
            return True
    elif act == OUT:
        if g_act == BEGIN:
            # O, Gold B --> False
            return False
        elif g_act == IN:
            # O, Gold I --> True
            return True
        elif g_act == LAST:
            # O, Gold L --> True
            return True
        elif g_act == OUT:
            # O, Gold O --> True
            return True
        elif g_act == UNIT:
            # O, Gold U --> False
            return False
    elif act == UNIT:
        if g_act == UNIT:
            # U, Gold U --> True iff tag match
            return tag == g_tag
        else:
            # U, Gold B --> False
            # U, Gold I --> False
            # U, Gold L --> False
            # U, Gold O --> False
            return False
    # Any (act, g_act) pair not listed above is rejected. This was the
    # implicit fall-through result before; it is now spelled out.
    return False


cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0:
    """Flag every move that is structurally valid in state `s`.

    moves[0] is always rejected. Returns the number of accepted moves; the
    assertion guards against returning 0, which is the error sentinel of
    the `except 0` clause.
    """
    cdef int n_accept = 0
    cdef int i
    cdef Move* m
    moves[0].accept = False
    for i in range(1, n_classes):
        m = &moves[i]
        if m.action == BEGIN:
            m.accept = can_begin(s, m.label)
        elif m.action == IN:
            m.accept = can_in(s, m.label)
        elif m.action == LAST:
            m.accept = can_last(s, m.label)
        elif m.action == UNIT:
            m.accept = can_unit(s, m.label)
        elif m.action == OUT:
            m.accept = can_out(s, m.label)
        else:
            # No validity rule exists for this action (e.g. MISSING), so
            # reject it instead of counting whatever value was left in
            # `accept` by an earlier call.
            m.accept = False
        n_accept += m.accept
    assert n_accept != 0
    return n_accept


cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s) except 0:
    """Flag every move that is both valid in `s` and accepted by is_oracle().

    golds is indexed by token position; golds[s.i] is the gold move for the
    current token. Returns the number of accepted moves (asserted non-zero).
    """
    cdef Move* g = &golds[s.i]
    # Look one token ahead only when that token exists. The guard must test
    # s.i + 1: testing s.i alone made the OUT fallback unreachable for any
    # valid index and read golds[s.length] on the last token.
    cdef ActionType next_act = golds[s.i + 1].action if s.i + 1 < s.length else OUT
    cdef bint is_sunk = entity_is_sunk(s, golds)
    cdef Move* m
    cdef int n_accept = 0
    cdef int i
    set_accept_if_valid(moves, n_classes, s)
    for i in range(1, n_classes):
        m = &moves[i]
        if not m.accept:
            continue
        m.accept = is_oracle(m.action, m.label, g.action, g.label,
                             next_act, is_sunk)
        n_accept += m.accept
    assert n_accept != 0
    return n_accept


cdef Move* best_accepted(Move* moves, weight_t* scores, int n) except NULL:
    """Return a pointer to the highest-scoring accepted move in moves[1:n].

    The score of moves[i] is scores[i - 1]. On ties the lowest index wins.

    Raises ValueError when no move in moves[1:n] is accepted.
    """
    cdef int first_accept = -1
    cdef int best
    cdef int i
    cdef weight_t score
    for i in range(1, n):
        if moves[i].accept:
            first_accept = i
            break
    if first_accept == -1:
        # ValueError rather than StandardError: the latter does not exist
        # on Python 3, while on Python 2 ValueError is a StandardError
        # subclass, so existing handlers keep working.
        raise ValueError("No accepted move to choose from")
    best = first_accept
    score = scores[first_accept - 1]
    for i in range(first_accept + 1, n):
        if moves[i].accept and scores[i - 1] > score:
            best = i
            score = scores[i - 1]
    return &moves[best]


cdef int transition(State *s, Move* move) except -1:
    """Apply `move` to `s`, record its class for the current token, advance.

    BEGIN opens an entity, LAST closes the open one, UNIT opens and closes
    one immediately; IN and OUT leave the entity bookkeeping untouched.
    """
    if move.action == BEGIN:
        begin_entity(s, move.label)
    elif move.action == IN:
        pass
    elif move.action == LAST:
        end_entity(s)
    elif move.action == UNIT:
        begin_entity(s, move.label)
        end_entity(s)
    elif move.action == OUT:
        pass
    s.tags[s.i] = move.clas
    s.i += 1
    return 0


def get_n_moves(n_tags):
    """Return the size of the move table for `n_tags` tags: 4 * n_tags + 1.

    NOTE(review): presumably four labelled actions (B, I, L, U) per tag plus
    one unlabelled move -- confirm against the caller that sizes the table.
    """
    return 4 * n_tags + 1


cdef int fill_moves(Move* moves, list tag_names) except -1:
    """Populate moves[i] from tag_names[i].

    Accepted tag names are 'ACTION-LABEL' (e.g. 'B-PER'), 'O', 'NULL' and
    'EOL'. Labels get integer ids in order of first appearance; '-' is the
    empty label with id 0.

    Raises ValueError for a tag name in none of those forms, or one whose
    action prefix is not in ACTION_NAMES.
    """
    cdef Move* m
    label_names = {'-': 0}
    for i, tag_name in enumerate(tag_names):
        m = &moves[i]
        if '-' in tag_name:
            # Split on the first hyphen only, so labels that themselves
            # contain a hyphen (e.g. 'B-WORK-OF-ART') stay intact.
            action_str, label = tag_name.split('-', 1)
        elif tag_name == 'O':
            action_str = 'O'
            label = '-'
        elif tag_name == 'NULL' or tag_name == 'EOL':
            action_str = '?'
            label = '-'
        else:
            # ValueError instead of the Python-2-only StandardError.
            raise ValueError(tag_name)
        m.action = ACTION_NAMES.index(action_str)
        m.label = label_names.setdefault(label, len(label_names))
        m.clas = i
    return 0