mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-28 02:04:07 +03:00
* Fixes to unshift/fast-forward strategy. Getting 91.55 greedy on NW dev, gold preproc
This commit is contained in:
parent
afd77a529b
commit
15e177d7a1
|
@ -5,6 +5,13 @@ from thinc.typedefs cimport weight_t
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
|
||||||
from .transition_system cimport TransitionSystem, Transition
|
from .transition_system cimport TransitionSystem, Transition
|
||||||
|
from ..gold cimport GoldParseC
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
|
||||||
|
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
|
||||||
|
|
||||||
|
|
|
@ -222,20 +222,20 @@ cdef class Break:
|
||||||
return False
|
return False
|
||||||
elif st.at_break():
|
elif st.at_break():
|
||||||
return False
|
return False
|
||||||
|
elif st.B(0) == 0:
|
||||||
|
return False
|
||||||
elif st.stack_depth() < 1:
|
elif st.stack_depth() < 1:
|
||||||
return False
|
return False
|
||||||
|
elif (st.S(0) + 1) != st.B(0):
|
||||||
|
# Must break at the token boundary
|
||||||
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateClass st, int label) nogil:
|
cdef int transition(StateClass st, int label) nogil:
|
||||||
st.set_break(st.B(0))
|
st.set_break(st.B(0))
|
||||||
while st.stack_depth() >= 2 and st.has_head(st.S(0)):
|
st.fast_forward()
|
||||||
st.pop()
|
|
||||||
if st.stack_depth() == 1:
|
|
||||||
st.pop()
|
|
||||||
else:
|
|
||||||
st.unshift()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
|
@ -243,32 +243,37 @@ cdef class Break:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
# When we break, we can't reach any arcs between stack and buffer
|
# Check for sentence boundary --- if it's here, we can't have any deps
|
||||||
# So cost is number of deps between S0...Sn and B0...Nn
|
# between stack and buffer, so rest of action is irrelevant.
|
||||||
cdef int cost = 0
|
s0_root = _get_root(s.S(0), gold)
|
||||||
cdef int i, j, B_i, S_i
|
b0_root = _get_root(s.B(0), gold)
|
||||||
for i in range(s.buffer_length()):
|
if s0_root == -1 or b0_root == -1 or s0_root != b0_root:
|
||||||
B_i = s.B(i)
|
return 0
|
||||||
for j in range(s.stack_depth()):
|
else:
|
||||||
S_i = s.S(j)
|
return 1
|
||||||
cost += gold.heads[B_i] == S_i
|
|
||||||
cost += gold.heads[S_i] == B_i
|
|
||||||
return cost
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||||
|
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
|
||||||
|
word = gold.heads[word]
|
||||||
|
if gold.labels[word] == -1:
|
||||||
|
return -1
|
||||||
|
else:
|
||||||
|
return word
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_parses):
|
def get_labels(cls, gold_parses):
|
||||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'root': True},
|
||||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
|
LEFT: {'root': True}, BREAK: {'root': True}}
|
||||||
for raw_text, sents in gold_parses:
|
for raw_text, sents in gold_parses:
|
||||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||||
for child, head, label in zip(ids, heads, labels):
|
for child, head, label in zip(ids, heads, labels):
|
||||||
if label != 'ROOT':
|
if label != 'root':
|
||||||
if head < child:
|
if head < child:
|
||||||
move_labels[RIGHT][label] = True
|
move_labels[RIGHT][label] = True
|
||||||
elif head > child:
|
elif head > child:
|
||||||
|
@ -341,6 +346,9 @@ cdef class ArcEager(TransitionSystem):
|
||||||
return t
|
return t
|
||||||
|
|
||||||
cdef int initialize_state(self, StateClass st) except -1:
|
cdef int initialize_state(self, StateClass st) except -1:
|
||||||
|
# Ensure sent_end is set to 0 throughout
|
||||||
|
for i in range(st.length):
|
||||||
|
st._sent[i].sent_end = False
|
||||||
st.fast_forward()
|
st.fast_forward()
|
||||||
|
|
||||||
cdef int finalize_state(self, StateClass st) except -1:
|
cdef int finalize_state(self, StateClass st) except -1:
|
||||||
|
|
|
@ -322,7 +322,6 @@ cdef class Out:
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef int g_tag = gold.ner[s.B(0)].label
|
cdef int g_tag = gold.ner[s.B(0)].label
|
||||||
|
|
||||||
|
|
||||||
if g_act == MISSING:
|
if g_act == MISSING:
|
||||||
return 0
|
return 0
|
||||||
elif g_act == BEGIN:
|
elif g_act == BEGIN:
|
||||||
|
|
|
@ -36,6 +36,8 @@ from ..tokens cimport Tokens, TokenC
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
|
|
||||||
from .arc_eager cimport TransitionSystem, Transition
|
from .arc_eager cimport TransitionSystem, Transition
|
||||||
|
from .arc_eager cimport push_cost, arc_cost
|
||||||
|
|
||||||
from .transition_system import OracleError
|
from .transition_system import OracleError
|
||||||
|
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
@ -95,11 +97,12 @@ cdef class Parser:
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
words = [w.orth_ for w in tokens]
|
words = [w.orth_ for w in tokens]
|
||||||
while not stcls.is_final():
|
while not stcls.is_final():
|
||||||
#print stcls.print_state(words)
|
|
||||||
fill_context(context, stcls)
|
fill_context(context, stcls)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, stcls)
|
guess = self.moves.best_valid(scores, stcls)
|
||||||
|
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
|
||||||
guess.do(stcls, guess.label)
|
guess.do(stcls, guess.label)
|
||||||
|
assert stcls._s_i >= 0
|
||||||
self.moves.finalize_state(stcls)
|
self.moves.finalize_state(stcls)
|
||||||
tokens.set_parse(stcls._sent)
|
tokens.set_parse(stcls._sent)
|
||||||
|
|
||||||
|
@ -128,12 +131,27 @@ cdef class Parser:
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
loss = 0
|
loss = 0
|
||||||
words = [w.orth_ for w in tokens]
|
words = [w.orth_ for w in tokens]
|
||||||
|
history = []
|
||||||
while not stcls.is_final():
|
while not stcls.is_final():
|
||||||
|
assert stcls._s_i >= 0
|
||||||
fill_context(context, stcls)
|
fill_context(context, stcls)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, stcls)
|
guess = self.moves.best_valid(scores, stcls)
|
||||||
best = self.moves.best_gold(scores, stcls, gold)
|
try:
|
||||||
|
best = self.moves.best_gold(scores, stcls, gold)
|
||||||
|
except:
|
||||||
|
history.append((self.moves.move_name(guess.move, guess.label), '!', stcls.print_state(words)))
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
print gold.orig_annot[i]
|
||||||
|
print '\n'.join('\t'.join(s) for s in history)
|
||||||
|
print words[gold.c.heads[stcls.S(0)]]
|
||||||
|
print words[gold.c.heads[stcls.B(0)]]
|
||||||
|
print push_cost(stcls, &gold.c, stcls.B(0))
|
||||||
|
print arc_cost(stcls, &gold.c, stcls.S(0), stcls.B(0))
|
||||||
|
self.moves.set_valid(self.moves._is_valid, stcls)
|
||||||
|
raise
|
||||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||||
|
history.append((self.moves.move_name(guess.move, guess.label), str(cost), stcls.print_state(words)))
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
guess.do(stcls, guess.label)
|
guess.do(stcls, guess.label)
|
||||||
loss += cost
|
loss += cost
|
||||||
|
|
|
@ -143,7 +143,7 @@ cdef class StateClass:
|
||||||
self._stack[self._s_i] = self.B(0)
|
self._stack[self._s_i] = self.B(0)
|
||||||
self._s_i += 1
|
self._s_i += 1
|
||||||
self._b_i += 1
|
self._b_i += 1
|
||||||
if self._b_i >= self._break:
|
if self._b_i > self._break:
|
||||||
self._break = -1
|
self._break = -1
|
||||||
|
|
||||||
cdef void pop(self) nogil:
|
cdef void pop(self) nogil:
|
||||||
|
@ -167,7 +167,7 @@ cdef class StateClass:
|
||||||
self.pop()
|
self.pop()
|
||||||
else:
|
else:
|
||||||
self.unshift()
|
self.unshift()
|
||||||
elif self.buffer_length() >= 2 and self.stack_depth() == 0:
|
elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0:
|
||||||
self.push()
|
self.push()
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
@ -208,10 +208,9 @@ cdef class StateClass:
|
||||||
self._sent[i].ent_iob = ent_iob
|
self._sent[i].ent_iob = ent_iob
|
||||||
self._sent[i].ent_type = ent_type
|
self._sent[i].ent_type = ent_type
|
||||||
|
|
||||||
cdef void set_break(self, int i) nogil:
|
cdef void set_break(self, int _) nogil:
|
||||||
if 0 <= i < self.length:
|
self._sent[self.B(0)].sent_end = True
|
||||||
self._sent[i].sent_end = True
|
self._break = self._b_i
|
||||||
self._break = i
|
|
||||||
|
|
||||||
cdef void clone(self, StateClass src) nogil:
|
cdef void clone(self, StateClass src) nogil:
|
||||||
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
|
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
|
||||||
|
@ -229,7 +228,7 @@ cdef class StateClass:
|
||||||
third = words[self.S(2)] + '_%d' % self.S_(2).head
|
third = words[self.S(2)] + '_%d' % self.S_(2).head
|
||||||
n0 = words[self.B(0)]
|
n0 = words[self.B(0)]
|
||||||
n1 = words[self.B(1)]
|
n1 = words[self.B(1)]
|
||||||
return ' '.join((str(self.buffer_length()), str(self.stack_depth()), third, second, top, '|', n0, n1))
|
return ' '.join((str(self.buffer_length()), str(self.B_(0).sent_end), str(self._b_i), str(self._break), str(self.length), str(self.stack_depth()), third, second, top, '|', n0, n1))
|
||||||
|
|
||||||
|
|
||||||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||||
|
|
Loading…
Reference in New Issue
Block a user