mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Switch to predict label on shift. Big increase in accuracy.
This commit is contained in:
parent
8f84e8a78b
commit
cf55b48ba6
|
@ -34,23 +34,23 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s,
|
|||
accept_o = False
|
||||
if g_start == s.curr.start and g_end == s.i:
|
||||
accept_r = True
|
||||
r_label = g_labels[s.curr.start]
|
||||
accept_s = False
|
||||
elif g_start == s.curr.start and g_end > s.i:
|
||||
accept_s = True
|
||||
s_label = s.curr.label
|
||||
accept_r = False
|
||||
elif g_starts[s.i] == s.i:
|
||||
accept_r = True
|
||||
r_label = 0
|
||||
accept_s = False
|
||||
else:
|
||||
accept_r = True
|
||||
accept_s = True
|
||||
r_label = 0
|
||||
s_label = s.curr.label
|
||||
else:
|
||||
accept_r = False
|
||||
if g_starts[s.i] == s.i:
|
||||
accept_s = True
|
||||
s_label = g_labels[s.i]
|
||||
accept_o = False
|
||||
else:
|
||||
accept_o = True
|
||||
|
@ -60,9 +60,9 @@ cdef int set_accept_if_oracle(Move* moves, int n, State* s,
|
|||
for i in range(1, n):
|
||||
m = &moves[i]
|
||||
if m.action == SHIFT:
|
||||
m.accept = accept_s
|
||||
m.accept = accept_s and m.label == s_label
|
||||
elif m.action == REDUCE:
|
||||
m.accept = accept_r and (r_label == 0 or m.label == r_label)
|
||||
m.accept = accept_r
|
||||
elif m.action == OUT:
|
||||
m.accept = accept_o
|
||||
n_accept += m.accept
|
||||
|
@ -77,7 +77,7 @@ cdef int set_accept_if_valid(Move* moves, int n, State* s) except 0:
|
|||
moves[0].accept = False
|
||||
for i in range(1, n):
|
||||
if moves[i].action == SHIFT:
|
||||
moves[i].accept = True
|
||||
moves[i].accept = moves[i].label == s.curr.label or not entity_is_open(s)
|
||||
elif moves[i].action == REDUCE:
|
||||
moves[i].accept = open_ent
|
||||
elif moves[i].action == OUT:
|
||||
|
@ -110,11 +110,16 @@ cdef int transition(State *s, Move* move) except -1:
|
|||
s.i += 1
|
||||
elif move.action == SHIFT:
|
||||
if not entity_is_open(s):
|
||||
begin_entity(s, 0)
|
||||
s.curr.start = s.i
|
||||
s.curr.label = move.label
|
||||
s.i += 1
|
||||
elif move.action == REDUCE:
|
||||
s.curr.label = move.label
|
||||
end_entity(s)
|
||||
s.curr.end = s.i
|
||||
s.ents[s.j] = s.curr
|
||||
s.j += 1
|
||||
s.curr.start = 0
|
||||
s.curr.label = -1
|
||||
s.curr.end = 0
|
||||
else:
|
||||
raise ValueError(move.action)
|
||||
|
||||
|
@ -132,16 +137,16 @@ cdef int fill_moves(Move* moves, int n, list entity_types) except -1:
|
|||
moves[i].action = MISSING
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
moves[i].clas = i
|
||||
moves[i].action = SHIFT
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
for entity_type in entity_types:
|
||||
moves[i].action = SHIFT
|
||||
moves[i].label = label_names.setdefault(entity_type, len(label_names))
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
moves[i].clas = i
|
||||
moves[i].action = OUT
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
for entity_type in entity_types:
|
||||
moves[i].action = REDUCE
|
||||
moves[i].label = label_names.setdefault(entity_type, len(label_names))
|
||||
moves[i].clas = i
|
||||
i += 1
|
||||
moves[i].action = REDUCE
|
||||
moves[i].clas = i
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user