* Switch to predict label on shift. Big increase in accuracy.

This commit is contained in:
Matthew Honnibal 2014-11-12 23:50:12 +11:00
parent 8f84e8a78b
commit cf55b48ba6

View File

@ -34,23 +34,23 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s,
accept_o = False accept_o = False
if g_start == s.curr.start and g_end == s.i: if g_start == s.curr.start and g_end == s.i:
accept_r = True accept_r = True
r_label = g_labels[s.curr.start]
accept_s = False accept_s = False
elif g_start == s.curr.start and g_end > s.i: elif g_start == s.curr.start and g_end > s.i:
accept_s = True accept_s = True
s_label = s.curr.label
accept_r = False accept_r = False
elif g_starts[s.i] == s.i: elif g_starts[s.i] == s.i:
accept_r = True accept_r = True
r_label = 0
accept_s = False accept_s = False
else: else:
accept_r = True accept_r = True
accept_s = True accept_s = True
r_label = 0 s_label = s.curr.label
else: else:
accept_r = False accept_r = False
if g_starts[s.i] == s.i: if g_starts[s.i] == s.i:
accept_s = True accept_s = True
s_label = g_labels[s.i]
accept_o = False accept_o = False
else: else:
accept_o = True accept_o = True
@ -60,9 +60,9 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s,
for i in range(1, n): for i in range(1, n):
m = &moves[i] m = &moves[i]
if m.action == SHIFT: if m.action == SHIFT:
m.accept = accept_s m.accept = accept_s and m.label == s_label
elif m.action == REDUCE: elif m.action == REDUCE:
m.accept = accept_r and (r_label == 0 or m.label == r_label) m.accept = accept_r
elif m.action == OUT: elif m.action == OUT:
m.accept = accept_o m.accept = accept_o
n_accept += m.accept n_accept += m.accept
@ -77,7 +77,7 @@ cdef int set_accept_if_valid(Move* moves, int n, State* s) except 0:
moves[0].accept = False moves[0].accept = False
for i in range(1, n): for i in range(1, n):
if moves[i].action == SHIFT: if moves[i].action == SHIFT:
moves[i].accept = True moves[i].accept = moves[i].label == s.curr.label or not entity_is_open(s)
elif moves[i].action == REDUCE: elif moves[i].action == REDUCE:
moves[i].accept = open_ent moves[i].accept = open_ent
elif moves[i].action == OUT: elif moves[i].action == OUT:
@ -110,11 +110,16 @@ cdef int transition(State *s, Move* move) except -1:
s.i += 1 s.i += 1
elif move.action == SHIFT: elif move.action == SHIFT:
if not entity_is_open(s): if not entity_is_open(s):
begin_entity(s, 0) s.curr.start = s.i
s.curr.label = move.label
s.i += 1 s.i += 1
elif move.action == REDUCE: elif move.action == REDUCE:
s.curr.label = move.label s.curr.end = s.i
end_entity(s) s.ents[s.j] = s.curr
s.j += 1
s.curr.start = 0
s.curr.label = -1
s.curr.end = 0
else: else:
raise ValueError(move.action) raise ValueError(move.action)
@ -132,16 +137,16 @@ cdef int fill_moves(Move* moves, int n, list entity_types) except -1:
moves[i].action = MISSING moves[i].action = MISSING
moves[i].label = 0 moves[i].label = 0
i += 1 i += 1
moves[i].clas = i for entity_type in entity_types:
moves[i].action = SHIFT moves[i].action = SHIFT
moves[i].label = 0 moves[i].label = label_names.setdefault(entity_type, len(label_names))
i += 1 moves[i].clas = i
i += 1
moves[i].clas = i moves[i].clas = i
moves[i].action = OUT moves[i].action = OUT
moves[i].label = 0 moves[i].label = 0
i += 1 i += 1
for entity_type in entity_types: moves[i].action = REDUCE
moves[i].action = REDUCE moves[i].clas = i
moves[i].label = label_names.setdefault(entity_type, len(label_names)) moves[i].label = 0
moves[i].clas = i i += 1
i += 1