From cf55b48ba6fe461d06918a240f1e177ad1eebbd5 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 12 Nov 2014 23:50:12 +1100 Subject: [PATCH] * Switch to predict label on shift. Big increase in accuracy. --- spacy/ner/io_moves.pyx | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/spacy/ner/io_moves.pyx b/spacy/ner/io_moves.pyx index 94f7c15c8..dc268e4a5 100644 --- a/spacy/ner/io_moves.pyx +++ b/spacy/ner/io_moves.pyx @@ -34,23 +34,23 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s, accept_o = False if g_start == s.curr.start and g_end == s.i: accept_r = True - r_label = g_labels[s.curr.start] accept_s = False elif g_start == s.curr.start and g_end > s.i: accept_s = True + s_label = s.curr.label accept_r = False elif g_starts[s.i] == s.i: accept_r = True - r_label = 0 accept_s = False else: accept_r = True accept_s = True - r_label = 0 + s_label = s.curr.label else: accept_r = False if g_starts[s.i] == s.i: accept_s = True + s_label = g_labels[s.i] accept_o = False else: accept_o = True @@ -60,9 +60,9 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s, for i in range(1, n): m = &moves[i] if m.action == SHIFT: - m.accept = accept_s + m.accept = accept_s and m.label == s_label elif m.action == REDUCE: - m.accept = accept_r and (r_label == 0 or m.label == r_label) + m.accept = accept_r elif m.action == OUT: m.accept = accept_o n_accept += m.accept @@ -77,7 +77,7 @@ cdef int set_accept_if_valid(Move* moves, int n, State* s) except 0: moves[0].accept = False for i in range(1, n): if moves[i].action == SHIFT: - moves[i].accept = True + moves[i].accept = moves[i].label == s.curr.label or not entity_is_open(s) elif moves[i].action == REDUCE: moves[i].accept = open_ent elif moves[i].action == OUT: @@ -110,11 +110,16 @@ cdef int transition(State *s, Move* move) except -1: s.i += 1 elif move.action == SHIFT: if not entity_is_open(s): - begin_entity(s, 0) + s.curr.start = s.i + s.curr.label = move.label s.i += 1 elif move.action == REDUCE: - s.curr.label = move.label - end_entity(s) + s.curr.end = s.i + s.ents[s.j] = s.curr + s.j += 1 + s.curr.start = 0 + s.curr.label = -1 + s.curr.end = 0 else: raise ValueError(move.action) @@ -132,16 +137,16 @@ cdef int fill_moves(Move* moves, int n, list entity_types) except -1: moves[i].action = MISSING moves[i].label = 0 i += 1 - moves[i].clas = i - moves[i].action = SHIFT - moves[i].label = 0 - i += 1 + for entity_type in entity_types: + moves[i].action = SHIFT + moves[i].label = label_names.setdefault(entity_type, len(label_names)) + moves[i].clas = i + i += 1 moves[i].clas = i moves[i].action = OUT moves[i].label = 0 i += 1 - for entity_type in entity_types: - moves[i].action = REDUCE - moves[i].label = label_names.setdefault(entity_type, len(label_names)) - moves[i].clas = i - i += 1 + moves[i].action = REDUCE + moves[i].clas = i + moves[i].label = 0 + i += 1