Rewrite oracle to not use fast-forward. Seems to work?

This commit is contained in:
Matthew Honnibal 2018-04-01 10:43:11 +02:00
parent c5574f48c7
commit e887b2330e
7 changed files with 107 additions and 118 deletions

View File

@ -157,6 +157,45 @@ cdef cppclass StateC:
else: else:
ids[i] = -1 ids[i] = -1
int can_push() nogil const:
if this.buffer_length == 0:
return 0
else:
return 1
int can_pop() nogil const:
if this.stack_depth() < 1:
return 0
else:
return 1
int can_arc() nogil const:
if this.at_break():
return 0
elif this.stack_depth() < 1:
return 0
elif this.buffer_length == 0:
return 0
else:
return 1
int can_break() nogil const:
if this.buffer_length == 0:
return False
elif this.B_(0).l_edge < 0:
return False
elif this._sent[this.B_(0).l_edge].sent_start < 0:
return False
elif this.stack_depth() < 1: # ?? I guess stops first action break?
return False
elif this.at_break():
return False
else:
return True
int can_split() nogil const:
return 0
int S(int i) nogil const: int S(int i) nogil const:
if i >= this._s_i: if i >= this._s_i:
return -1 return -1
@ -265,7 +304,7 @@ cdef cppclass StateC:
return this._n_until_break == 0 return this._n_until_break == 0
bint is_final() nogil const: bint is_final() nogil const:
return this.stack_depth() <= 0 and this.buffer_length == 0 return this.stack_depth() <= 1 and this.buffer_length == 0
bint has_head(int i) nogil const: bint has_head(int i) nogil const:
return this.safe_get(i).head != 0 return this.safe_get(i).head != 0
@ -287,6 +326,12 @@ cdef cppclass StateC:
int stack_depth() nogil const: int stack_depth() nogil const:
return this._s_i return this._s_i
int segment_length() nogil const:
if this._n_until_break != -1:
return this._n_until_break
else:
return this.buffer_length
uint64_t hash() nogil const: uint64_t hash() nogil const:
cdef TokenC[11] sig cdef TokenC[11] sig
sig[0] = this.S_(2)[0] sig[0] = this.S_(2)[0]
@ -460,69 +505,3 @@ cdef cppclass StateC:
this._n_until_break = src._n_until_break this._n_until_break = src._n_until_break
this.offset = src.offset this.offset = src.offset
this._empty_token = src._empty_token this._empty_token = src._empty_token
void fast_forward() nogil:
# space token attachement policy:
# - attach space tokens always to the last preceding real token
# - except if it's the beginning of a sentence, then attach to the first following
# - boundary case: a document containing multiple space tokens but nothing else,
# then make the last space token the head of all others
while is_space_token(this.B_(0)) \
or this.eol() \
or this.stack_depth() == 0:
if this.eol():
# remove the last sentence's root from the stack
if this.stack_depth() == 1:
this.pop()
# parser got stuck: reduce stack or unshift
elif this.stack_depth() > 1:
if this.has_head(this.S(0)):
this.pop()
else:
this.unshift()
# stack is empty but there is another sentence on the buffer
elif this.buffer_length != 0:
this.push()
else: # stack empty and nothing else coming
break
elif is_space_token(this.B_(0)):
# the normal case: we're somewhere inside a sentence
if this.stack_depth() > 0:
# assert not is_space_token(this.S_(0))
# attach all coming space tokens to their last preceding
# real token (which should be on the top of the stack)
while is_space_token(this.B_(0)):
this.add_arc(this.S(0),this.B(0),0)
this.push()
this.pop()
# the rare case: we're at the beginning of a document:
# space tokens are attached to the first real token on the buffer
elif this.stack_depth() == 0:
# store all space tokens on the stack until a real token shows up
# or the last token on the buffer is reached
while is_space_token(this.B_(0)) and this.buffer_length > 1:
this.push()
# empty the stack by attaching all space tokens to the
# first token on the buffer
# boundary case: if all tokens are space tokens, the last one
# becomes the head of all others
while this.stack_depth() > 0:
this.add_arc(this.B(0),this.S(0),0)
this.pop()
# move the first token onto the stack
this.push()
elif this.stack_depth() == 0:
# for one token sentences (?)
if this.buffer_length == 1:
this.push()
this.pop()
# with an empty stack and a non-empty buffer
# only shift is valid anyway
elif this.buffer_length != 0:
this.push()
else: # can this even happen?
break

View File

@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0 cdef weight_t cost = 0
cdef int i, B_i cdef int i, B_i
for i in range(stcls.c.buffer_length): for i in range(stcls.c.segment_length()):
B_i = stcls.B(i) B_i = stcls.B(i)
cost += gold.heads[B_i] == target cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i cost += gold.heads[target] == B_i
@ -74,8 +74,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
break break
if BINARY_COSTS and cost >= 1: if BINARY_COSTS and cost >= 1:
return cost return cost
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: #if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
cost += 1 # cost += 1
return cost return cost
@ -117,15 +117,23 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift: cdef class Shift:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
sent_start = st._sent[st.B_(0).l_edge].sent_start if not st.can_push():
return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1 return False
elif st.stack_depth() == 0: # If the stack is empty, we must push
return True
elif st.shifted[st.B(0)]:
return False
elif st.at_break():
return False
else:
return True
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
if label != 0: #if label != 0:
st.split(st.B(1), label) # st.split(st.B(1), label)
st.shifted[st.B(0)] = 1
st.push() st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil: cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
@ -138,7 +146,7 @@ cdef class Shift:
@staticmethod @staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
return 0 return 0
#if gold.fused_tokens[s.B(1)] == label: #if gold.fused_tokens[s.B(1)] == label: TODO
# return 0 # return 0
#else: #else:
# return 1 # return 1
@ -147,15 +155,21 @@ cdef class Shift:
cdef class Reduce: cdef class Reduce:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
return st.stack_depth() >= 2 if st.stack_depth() >= 2:
return True
elif st.at_break() and st.stack_depth() == 1:
return True
else:
return False
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
if st.has_head(st.S(0)): if st.has_head(st.S(0)):
st.pop() st.pop()
elif st.stack_depth() == 1 and st.at_break():
st.pop()
else: else:
st.unshift() st.unshift()
st.fast_forward()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
@ -165,15 +179,15 @@ cdef class Reduce:
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
cost = pop_cost(st, gold, st.S(0)) cost = pop_cost(st, gold, st.S(0))
if not st.has_head(st.S(0)): if not st.has_head(st.S(0)):
# Decrement cost for the arcs e save # Decrement cost for the arcs we save
for i in range(1, st.stack_depth()): for i in range(1, st.stack_depth()):
S_i = st.S(i) S_i = st.S(i)
if gold.heads[st.S(0)] == S_i: if gold.heads[st.S(0)] == S_i:
cost -= 1 cost -= 1
if gold.heads[S_i] == st.S(0): if gold.heads[S_i] == st.S(0):
cost -= 1 cost -= 1
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: #if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
cost -= 1 # cost -= 1
return cost return cost
@staticmethod @staticmethod
@ -184,18 +198,18 @@ cdef class Reduce:
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
sent_start = st._sent[st.B_(0).l_edge].sent_start return st.can_arc()
return sent_start != 1
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.B(0), st.S(0), label) st.add_arc(st.B(0), st.S(0), label)
st.pop() st.pop()
st.fast_forward()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) cdef weight_t move_cost = LeftArc.move_cost(s, gold)
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
return move_cost + label_cost
@staticmethod @staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
@ -220,14 +234,17 @@ cdef class RightArc:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle. # If there's (perhaps partial) parse pre-set, don't allow cycle.
sent_start = st._sent[st.B_(0).l_edge].sent_start if not st.can_arc():
return sent_start != 1 and st.H(st.S(0)) != st.B(0) return 0
elif st.H(st.S(0)) == st.B(0):
return 0
else:
return 1
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label) st.add_arc(st.S(0), st.B(0), label)
st.push() st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
@ -253,21 +270,13 @@ cdef class Break:
cdef int i cdef int i
if not USE_BREAK: if not USE_BREAK:
return False return False
elif st.at_break():
return False
elif st.stack_depth() < 1:
return False
elif st.B_(0).l_edge < 0:
return False
elif st._sent[st.B_(0).l_edge].sent_start < 0:
return False
else: else:
return True return st.can_break()
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.set_break(0) st.set_break(0)
st.fast_forward() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
@ -317,7 +326,6 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
st._sent[i].dep = 0 st._sent[i].dep = 0
st._sent[i].l_kids = 0 st._sent[i].l_kids = 0
st._sent[i].r_kids = 0 st._sent[i].r_kids = 0
st.fast_forward()
return <void*>st return <void*>st
@ -520,7 +528,6 @@ cdef class ArcEager(TransitionSystem):
st._sent[i].dep = 0 st._sent[i].dep = 0
st._sent[i].l_kids = 0 st._sent[i].l_kids = 0
st._sent[i].r_kids = 0 st._sent[i].r_kids = 0
st.fast_forward()
cdef int finalize_state(self, StateC* st) nogil: cdef int finalize_state(self, StateC* st) nogil:
cdef int i cdef int i

View File

@ -137,6 +137,3 @@ cdef class StateClass:
cdef inline void clone(self, StateClass src) nogil: cdef inline void clone(self, StateClass src) nogil:
self.c.clone(src.c) self.c.clone(src.c)
cdef inline void fast_forward(self) nogil:
self.c.fast_forward()

View File

@ -30,28 +30,32 @@ cdef class StateClass:
def get_S(self, int i): def get_S(self, int i):
return self.c.S(i) return self.c.S(i)
def push_stack(self, fast_forward=True): def can_push(self):
return self.c.can_push()
def can_pop(self):
return self.c.can_pop()
def can_break(self):
return self.c.can_break()
def can_arc(self):
return self.c.can_arc()
def push_stack(self):
self.c.push() self.c.push()
if fast_forward:
self.c.fast_forward()
def pop_stack(self, fast_forward=True): def pop_stack(self):
self.c.pop() self.c.pop()
if fast_forward:
self.c.fast_forward()
def unshift(self, fast_forward=True): def unshift(self):
self.c.unshift() self.c.unshift()
if fast_forward:
self.c.fast_forward()
def set_break(self, int i): def set_break(self, int i):
self.c.set_break(i) self.c.set_break(i)
def split_token(self, int i, int n, fast_forward=True): def split_token(self, int i, int n):
self.c.split(i, n) self.c.split(i, n)
if fast_forward:
self.c.fast_forward()
def get_doc(self, vocab): def get_doc(self, vocab):
cdef Doc doc = Doc(vocab) cdef Doc doc = Doc(vocab)

View File

@ -41,7 +41,6 @@ def test_init_parser(parser):
pass pass
# TODO: This is flakey, because it depends on what the parser first learns. # TODO: This is flakey, because it depends on what the parser first learns.
@pytest.mark.xfail
def test_add_label(parser): def test_add_label(parser):
doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd']) doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
doc = parser(doc) doc = parser(doc)

View File

@ -17,7 +17,10 @@ def vocab():
@pytest.fixture @pytest.fixture
def moves(vocab): def moves(vocab):
aeager = ArcEager(vocab.strings, {}) aeager = ArcEager(vocab.strings)
aeager.add_action(0, '')
aeager.add_action(1, '')
aeager.add_action(4, 'ROOT')
aeager.add_action(2, 'nsubj') aeager.add_action(2, 'nsubj')
aeager.add_action(3, 'dobj') aeager.add_action(3, 'dobj')
aeager.add_action(2, 'aux') aeager.add_action(2, 'aux')

View File

@ -39,5 +39,5 @@ def test_split():
doc = get_doc('abcd') doc = get_doc('abcd')
state = StateClass(doc, max_split=3) state = StateClass(doc, max_split=3)
assert state.queue == [0, 1, 2, 3] assert state.queue == [0, 1, 2, 3]
state.split_token(1, 2, fast_forward=False) state.split_token(1, 2)
assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3] assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]