* Fixes to unshift/fast-forward strategy. Getting 91.55 greedy on NW dev, gold preproc

This commit is contained in:
Matthew Honnibal 2015-06-12 01:50:23 +02:00
parent afd77a529b
commit 15e177d7a1
5 changed files with 61 additions and 30 deletions

View File

@ -5,6 +5,13 @@ from thinc.typedefs cimport weight_t
from .stateclass cimport StateClass
from .transition_system cimport TransitionSystem, Transition
from ..gold cimport GoldParseC
cdef class ArcEager(TransitionSystem):
pass
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil

View File

@ -222,20 +222,20 @@ cdef class Break:
return False
elif st.at_break():
return False
elif st.B(0) == 0:
return False
elif st.stack_depth() < 1:
return False
elif (st.S(0) + 1) != st.B(0):
# Must break at the token boundary
return False
else:
return True
@staticmethod
cdef int transition(StateClass st, int label) nogil:
st.set_break(st.B(0))
while st.stack_depth() >= 2 and st.has_head(st.S(0)):
st.pop()
if st.stack_depth() == 1:
st.pop()
else:
st.unshift()
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -243,32 +243,37 @@ cdef class Break:
@staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
# When we break, we can't reach any arcs between stack and buffer
# So cost is number of deps between S0...Sn and B0...Nn
cdef int cost = 0
cdef int i, j, B_i, S_i
for i in range(s.buffer_length()):
B_i = s.B(i)
for j in range(s.stack_depth()):
S_i = s.S(j)
cost += gold.heads[B_i] == S_i
cost += gold.heads[S_i] == B_i
return cost
# Check for sentence boundary --- if it's here, we can't have any deps
# between stack and buffer, so rest of action is irrelevant.
s0_root = _get_root(s.S(0), gold)
b0_root = _get_root(s.B(0), gold)
if s0_root == -1 or b0_root == -1 or s0_root != b0_root:
return 0
else:
return 1
@staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
cdef int _get_root(int word, const GoldParseC* gold) nogil:
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
word = gold.heads[word]
if gold.labels[word] == -1:
return -1
else:
return word
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'root': True},
LEFT: {'root': True}, BREAK: {'root': True}}
for raw_text, sents in gold_parses:
for (ids, words, tags, heads, labels, iob), ctnts in sents:
for child, head, label in zip(ids, heads, labels):
if label != 'ROOT':
if label != 'root':
if head < child:
move_labels[RIGHT][label] = True
elif head > child:
@ -341,6 +346,9 @@ cdef class ArcEager(TransitionSystem):
return t
cdef int initialize_state(self, StateClass st) except -1:
# Ensure sent_end is set to 0 throughout
for i in range(st.length):
st._sent[i].sent_end = False
st.fast_forward()
cdef int finalize_state(self, StateClass st) except -1:

View File

@ -322,7 +322,6 @@ cdef class Out:
cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label
if g_act == MISSING:
return 0
elif g_act == BEGIN:

View File

@ -36,6 +36,8 @@ from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition
from .arc_eager cimport push_cost, arc_cost
from .transition_system import OracleError
from ..gold cimport GoldParse
@ -95,11 +97,12 @@ cdef class Parser:
cdef Transition guess
words = [w.orth_ for w in tokens]
while not stcls.is_final():
#print stcls.print_state(words)
fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
guess.do(stcls, guess.label)
assert stcls._s_i >= 0
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
@ -128,12 +131,27 @@ cdef class Parser:
cdef atom_t[CONTEXT_SIZE] context
loss = 0
words = [w.orth_ for w in tokens]
history = []
while not stcls.is_final():
assert stcls._s_i >= 0
fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, stcls, gold)
try:
best = self.moves.best_gold(scores, stcls, gold)
except:
history.append((self.moves.move_name(guess.move, guess.label), '!', stcls.print_state(words)))
for i, word in enumerate(words):
print gold.orig_annot[i]
print '\n'.join('\t'.join(s) for s in history)
print words[gold.c.heads[stcls.S(0)]]
print words[gold.c.heads[stcls.B(0)]]
print push_cost(stcls, &gold.c, stcls.B(0))
print arc_cost(stcls, &gold.c, stcls.S(0), stcls.B(0))
self.moves.set_valid(self.moves._is_valid, stcls)
raise
cost = guess.get_cost(stcls, &gold.c, guess.label)
history.append((self.moves.move_name(guess.move, guess.label), str(cost), stcls.print_state(words)))
self.model.update(context, guess.clas, best.clas, cost)
guess.do(stcls, guess.label)
loss += cost

View File

@ -143,7 +143,7 @@ cdef class StateClass:
self._stack[self._s_i] = self.B(0)
self._s_i += 1
self._b_i += 1
if self._b_i >= self._break:
if self._b_i > self._break:
self._break = -1
cdef void pop(self) nogil:
@ -167,7 +167,7 @@ cdef class StateClass:
self.pop()
else:
self.unshift()
elif self.buffer_length() >= 2 and self.stack_depth() == 0:
elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0:
self.push()
else:
break
@ -208,10 +208,9 @@ cdef class StateClass:
self._sent[i].ent_iob = ent_iob
self._sent[i].ent_type = ent_type
cdef void set_break(self, int i) nogil:
if 0 <= i < self.length:
self._sent[i].sent_end = True
self._break = i
cdef void set_break(self, int _) nogil:
self._sent[self.B(0)].sent_end = True
self._break = self._b_i
cdef void clone(self, StateClass src) nogil:
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
@ -229,7 +228,7 @@ cdef class StateClass:
third = words[self.S(2)] + '_%d' % self.S_(2).head
n0 = words[self.B(0)]
n1 = words[self.B(1)]
return ' '.join((str(self.buffer_length()), str(self.stack_depth()), third, second, top, '|', n0, n1))
return ' '.join((str(self.buffer_length()), str(self.B_(0).sent_end), str(self._b_i), str(self._break), str(self.length), str(self.stack_depth()), third, second, top, '|', n0, n1))
# From https://en.wikipedia.org/wiki/Hamming_weight