* Bug fixes to NER

This commit is contained in:
Matthew Honnibal 2014-11-10 17:39:23 +11:00
parent d7b2843643
commit af9ed18cf1
5 changed files with 39 additions and 16 deletions

View File

@ -9,6 +9,7 @@ cdef struct Entity:
cdef struct State: cdef struct State:
Entity curr
Entity* ents Entity* ents
int* tags int* tags
int i int i

View File

@ -2,13 +2,16 @@ from .moves cimport BEGIN, UNIT
cdef int begin_entity(State* s, label) except -1: cdef int begin_entity(State* s, label) except -1:
s.j += 1 s.curr.start = s.i
s.ents[s.j].start = s.i s.curr.label = label
s.ents[s.j].label = label
cdef int end_entity(State* s) except -1: cdef int end_entity(State* s) except -1:
s.ents[s.j].end = s.i + 1 s.curr.end = s.i + 1
s.curr[s.j] = s.curr
s.curr.start = 0
s.curr.label = -1
s.curr.end = 0
cdef State* init_state(Pool mem, int sent_length) except NULL: cdef State* init_state(Pool mem, int sent_length) except NULL:
@ -17,24 +20,24 @@ cdef State* init_state(Pool mem, int sent_length) except NULL:
s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity)) s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity))
for i in range(sent_length): for i in range(sent_length):
s.ents[i].label = -1 s.ents[i].label = -1
s.curr.label = -1
s.tags = <int*>mem.alloc(sent_length, sizeof(int)) s.tags = <int*>mem.alloc(sent_length, sizeof(int))
s.length = sent_length s.length = sent_length
return s return s
cdef bint entity_is_open(State *s) except -1: cdef bint entity_is_open(State *s) except -1:
return s.j >= 0 and s.ents[s.j].label != -1 return s.curr.label != -1
cdef bint entity_is_sunk(State *s, Move* golds) except -1: cdef bint entity_is_sunk(State *s, Move* golds) except -1:
if not entity_is_open(s): if not entity_is_open(s):
return False return False
cdef Entity* ent = &s.ents[s.j] cdef Move* gold = &golds[s.curr.start]
cdef Move* gold = &golds[ent.start]
if gold.action != BEGIN and gold.action != UNIT: if gold.action != BEGIN and gold.action != UNIT:
return True return True
elif gold.label != ent.label: elif gold.label != s.curr.label:
return True return True
else: else:
return False return False

View File

@ -1,8 +1,11 @@
from __future__ import unicode_literals
from ._state cimport begin_entity from ._state cimport begin_entity
from ._state cimport end_entity from ._state cimport end_entity
from ._state cimport entity_is_open from ._state cimport entity_is_open
from ._state cimport entity_is_sunk from ._state cimport entity_is_sunk
ACTION_NAMES = ['' for _ in range(N_ACTIONS)] ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
ACTION_NAMES[<int>BEGIN] = 'B' ACTION_NAMES[<int>BEGIN] = 'B'
ACTION_NAMES[<int>IN] = 'I' ACTION_NAMES[<int>IN] = 'I'
@ -16,11 +19,11 @@ cdef bint can_begin(State* s, int label):
cdef bint can_in(State* s, int label): cdef bint can_in(State* s, int label):
return entity_is_open(s) and s.ents[s.j].tag == label return entity_is_open(s) and s.ents[s.j].label == label
cdef bint can_last(State* s, int label): cdef bint can_last(State* s, int label):
return entity_is_open(s) and s.ents[s.j].tag == label return entity_is_open(s) and s.ents[s.j].label == label
cdef bint can_unit(State* s, int label): cdef bint can_unit(State* s, int label):
@ -119,6 +122,7 @@ cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0:
elif m.action == OUT: elif m.action == OUT:
m.accept = can_out(s, m.label) m.accept = can_out(s, m.label)
n_accept += m.accept n_accept += m.accept
assert n_accept != 0
return n_accept return n_accept
@ -133,6 +137,7 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s)
m.accept = is_oracle(<ActionType>m.action, m.label, <ActionType>g.action, m.accept = is_oracle(<ActionType>m.action, m.label, <ActionType>g.action,
g.label, next_act, is_sunk) g.label, next_act, is_sunk)
n_accept += m.accept n_accept += m.accept
assert n_accept != 0
return n_accept return n_accept
@ -182,6 +187,7 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
for label in range(n_tags): for label in range(n_tags):
moves[i].action = IN moves[i].action = IN
moves[i].label = label moves[i].label = label
i += 1
for label in range(n_tags): for label in range(n_tags):
moves[i].action = LAST moves[i].action = LAST
moves[i].label = label moves[i].label = label
@ -190,4 +196,5 @@ cdef int fill_moves(Move* moves, int n_tags) except -1:
moves[i].action = UNIT moves[i].action = UNIT
moves[i].label = label moves[i].label = label
i += 1 i += 1
moves[i].label == OUT moves[i].action = OUT
moves[i].label = 0

View File

@ -12,3 +12,5 @@ cdef class PyState:
cdef Move* _moves cdef Move* _moves
cdef State* _s cdef State* _s
cdef Move* _get_move(self, unicode move_name) except NULL

View File

@ -1,7 +1,10 @@
from __future__ import unicode_literals
from ._state cimport init_state from ._state cimport init_state
from ._state cimport entity_is_open from ._state cimport entity_is_open
from .moves cimport fill_moves from .moves cimport fill_moves
from .moves cimport transition from .moves cimport transition
from .moves cimport set_accept_if_valid
from .moves import get_n_moves from .moves import get_n_moves
from .moves import ACTION_NAMES from .moves import ACTION_NAMES
@ -19,16 +22,23 @@ cdef class PyState:
for i in range(self.n_classes): for i in range(self.n_classes):
m = &self._moves[i] m = &self._moves[i]
action_name = ACTION_NAMES[m.action] action_name = ACTION_NAMES[m.action]
if action_name == 'O':
self.moves_by_name['O'] = i
else:
tag_name = tag_names[m.label] tag_name = tag_names[m.label]
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
cdef Move* _get_move(self, unicode move_name) except NULL:
return &self._moves[self.moves_by_name[move_name]]
def transition(self, unicode move_name): def transition(self, unicode move_name):
cdef int m_i = self.moves_by_name[move_name] cdef Move* m = self._get_move(move_name)
cdef Move* m = &self._moves[m_i]
transition(self._s, m) transition(self._s, m)
def is_valid(self, unicode move_name): def is_valid(self, unicode move_name):
pass cdef Move* m = self._get_move(move_name)
set_accept_if_valid(self._moves, self.n_classes, self._s)
return m.accept
def is_gold(self, unicode move_name): def is_gold(self, unicode move_name):
pass pass