mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Fixes to unshift/fast-forward strategy. Getting 91.55 greedy on NW dev, gold preproc
This commit is contained in:
parent
afd77a529b
commit
15e177d7a1
|
@ -5,6 +5,13 @@ from thinc.typedefs cimport weight_t
|
|||
from .stateclass cimport StateClass
|
||||
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
pass
|
||||
|
||||
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
|
||||
|
||||
|
|
|
@ -222,20 +222,20 @@ cdef class Break:
|
|||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
elif st.B(0) == 0:
|
||||
return False
|
||||
elif st.stack_depth() < 1:
|
||||
return False
|
||||
elif (st.S(0) + 1) != st.B(0):
|
||||
# Must break at the token boundary
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_break(st.B(0))
|
||||
while st.stack_depth() >= 2 and st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
if st.stack_depth() == 1:
|
||||
st.pop()
|
||||
else:
|
||||
st.unshift()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
|
@ -243,32 +243,37 @@ cdef class Break:
|
|||
|
||||
@staticmethod
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
# When we break, we can't reach any arcs between stack and buffer
|
||||
# So cost is number of deps between S0...Sn and B0...Nn
|
||||
cdef int cost = 0
|
||||
cdef int i, j, B_i, S_i
|
||||
for i in range(s.buffer_length()):
|
||||
B_i = s.B(i)
|
||||
for j in range(s.stack_depth()):
|
||||
S_i = s.S(j)
|
||||
cost += gold.heads[B_i] == S_i
|
||||
cost += gold.heads[S_i] == B_i
|
||||
return cost
|
||||
# Check for sentence boundary --- if it's here, we can't have any deps
|
||||
# between stack and buffer, so rest of action is irrelevant.
|
||||
s0_root = _get_root(s.S(0), gold)
|
||||
b0_root = _get_root(s.B(0), gold)
|
||||
if s0_root == -1 or b0_root == -1 or s0_root != b0_root:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
|
||||
word = gold.heads[word]
|
||||
if gold.labels[word] == -1:
|
||||
return -1
|
||||
else:
|
||||
return word
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
@classmethod
|
||||
def get_labels(cls, gold_parses):
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'root': True},
|
||||
LEFT: {'root': True}, BREAK: {'root': True}}
|
||||
for raw_text, sents in gold_parses:
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label != 'ROOT':
|
||||
if label != 'root':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
|
@ -341,6 +346,9 @@ cdef class ArcEager(TransitionSystem):
|
|||
return t
|
||||
|
||||
cdef int initialize_state(self, StateClass st) except -1:
|
||||
# Ensure sent_end is set to 0 throughout
|
||||
for i in range(st.length):
|
||||
st._sent[i].sent_end = False
|
||||
st.fast_forward()
|
||||
|
||||
cdef int finalize_state(self, StateClass st) except -1:
|
||||
|
|
|
@ -322,7 +322,6 @@ cdef class Out:
|
|||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
elif g_act == BEGIN:
|
||||
|
|
|
@ -36,6 +36,8 @@ from ..tokens cimport Tokens, TokenC
|
|||
from ..strings cimport StringStore
|
||||
|
||||
from .arc_eager cimport TransitionSystem, Transition
|
||||
from .arc_eager cimport push_cost, arc_cost
|
||||
|
||||
from .transition_system import OracleError
|
||||
|
||||
from ..gold cimport GoldParse
|
||||
|
@ -95,11 +97,12 @@ cdef class Parser:
|
|||
cdef Transition guess
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not stcls.is_final():
|
||||
#print stcls.print_state(words)
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
|
||||
guess.do(stcls, guess.label)
|
||||
assert stcls._s_i >= 0
|
||||
self.moves.finalize_state(stcls)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
|
@ -128,12 +131,27 @@ cdef class Parser:
|
|||
cdef atom_t[CONTEXT_SIZE] context
|
||||
loss = 0
|
||||
words = [w.orth_ for w in tokens]
|
||||
history = []
|
||||
while not stcls.is_final():
|
||||
assert stcls._s_i >= 0
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
try:
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
except:
|
||||
history.append((self.moves.move_name(guess.move, guess.label), '!', stcls.print_state(words)))
|
||||
for i, word in enumerate(words):
|
||||
print gold.orig_annot[i]
|
||||
print '\n'.join('\t'.join(s) for s in history)
|
||||
print words[gold.c.heads[stcls.S(0)]]
|
||||
print words[gold.c.heads[stcls.B(0)]]
|
||||
print push_cost(stcls, &gold.c, stcls.B(0))
|
||||
print arc_cost(stcls, &gold.c, stcls.S(0), stcls.B(0))
|
||||
self.moves.set_valid(self.moves._is_valid, stcls)
|
||||
raise
|
||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||
history.append((self.moves.move_name(guess.move, guess.label), str(cost), stcls.print_state(words)))
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(stcls, guess.label)
|
||||
loss += cost
|
||||
|
|
|
@ -143,7 +143,7 @@ cdef class StateClass:
|
|||
self._stack[self._s_i] = self.B(0)
|
||||
self._s_i += 1
|
||||
self._b_i += 1
|
||||
if self._b_i >= self._break:
|
||||
if self._b_i > self._break:
|
||||
self._break = -1
|
||||
|
||||
cdef void pop(self) nogil:
|
||||
|
@ -167,7 +167,7 @@ cdef class StateClass:
|
|||
self.pop()
|
||||
else:
|
||||
self.unshift()
|
||||
elif self.buffer_length() >= 2 and self.stack_depth() == 0:
|
||||
elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0:
|
||||
self.push()
|
||||
else:
|
||||
break
|
||||
|
@ -208,10 +208,9 @@ cdef class StateClass:
|
|||
self._sent[i].ent_iob = ent_iob
|
||||
self._sent[i].ent_type = ent_type
|
||||
|
||||
cdef void set_break(self, int i) nogil:
|
||||
if 0 <= i < self.length:
|
||||
self._sent[i].sent_end = True
|
||||
self._break = i
|
||||
cdef void set_break(self, int _) nogil:
|
||||
self._sent[self.B(0)].sent_end = True
|
||||
self._break = self._b_i
|
||||
|
||||
cdef void clone(self, StateClass src) nogil:
|
||||
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
|
||||
|
@ -229,7 +228,7 @@ cdef class StateClass:
|
|||
third = words[self.S(2)] + '_%d' % self.S_(2).head
|
||||
n0 = words[self.B(0)]
|
||||
n1 = words[self.B(1)]
|
||||
return ' '.join((str(self.buffer_length()), str(self.stack_depth()), third, second, top, '|', n0, n1))
|
||||
return ' '.join((str(self.buffer_length()), str(self.B_(0).sent_end), str(self._b_i), str(self._break), str(self.length), str(self.stack_depth()), third, second, top, '|', n0, n1))
|
||||
|
||||
|
||||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||
|
|
Loading…
Reference in New Issue
Block a user