mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Rewrite oracle to not use fast-forward. Seems to work?
This commit is contained in:
parent
c5574f48c7
commit
e887b2330e
|
@ -157,6 +157,45 @@ cdef cppclass StateC:
|
|||
else:
|
||||
ids[i] = -1
|
||||
|
||||
int can_push() nogil const:
|
||||
if this.buffer_length == 0:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
int can_pop() nogil const:
|
||||
if this.stack_depth() < 1:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
int can_arc() nogil const:
|
||||
if this.at_break():
|
||||
return 0
|
||||
elif this.stack_depth() < 1:
|
||||
return 0
|
||||
elif this.buffer_length == 0:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
int can_break() nogil const:
|
||||
if this.buffer_length == 0:
|
||||
return False
|
||||
elif this.B_(0).l_edge < 0:
|
||||
return False
|
||||
elif this._sent[this.B_(0).l_edge].sent_start < 0:
|
||||
return False
|
||||
elif this.stack_depth() < 1: # ?? I guess stops first action break?
|
||||
return False
|
||||
elif this.at_break():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
int can_split() nogil const:
|
||||
return 0
|
||||
|
||||
int S(int i) nogil const:
|
||||
if i >= this._s_i:
|
||||
return -1
|
||||
|
@ -265,7 +304,7 @@ cdef cppclass StateC:
|
|||
return this._n_until_break == 0
|
||||
|
||||
bint is_final() nogil const:
|
||||
return this.stack_depth() <= 0 and this.buffer_length == 0
|
||||
return this.stack_depth() <= 1 and this.buffer_length == 0
|
||||
|
||||
bint has_head(int i) nogil const:
|
||||
return this.safe_get(i).head != 0
|
||||
|
@ -287,6 +326,12 @@ cdef cppclass StateC:
|
|||
int stack_depth() nogil const:
|
||||
return this._s_i
|
||||
|
||||
int segment_length() nogil const:
|
||||
if this._n_until_break != -1:
|
||||
return this._n_until_break
|
||||
else:
|
||||
return this.buffer_length
|
||||
|
||||
uint64_t hash() nogil const:
|
||||
cdef TokenC[11] sig
|
||||
sig[0] = this.S_(2)[0]
|
||||
|
@ -460,69 +505,3 @@ cdef cppclass StateC:
|
|||
this._n_until_break = src._n_until_break
|
||||
this.offset = src.offset
|
||||
this._empty_token = src._empty_token
|
||||
|
||||
void fast_forward() nogil:
|
||||
# space token attachement policy:
|
||||
# - attach space tokens always to the last preceding real token
|
||||
# - except if it's the beginning of a sentence, then attach to the first following
|
||||
# - boundary case: a document containing multiple space tokens but nothing else,
|
||||
# then make the last space token the head of all others
|
||||
|
||||
while is_space_token(this.B_(0)) \
|
||||
or this.eol() \
|
||||
or this.stack_depth() == 0:
|
||||
if this.eol():
|
||||
# remove the last sentence's root from the stack
|
||||
if this.stack_depth() == 1:
|
||||
this.pop()
|
||||
# parser got stuck: reduce stack or unshift
|
||||
elif this.stack_depth() > 1:
|
||||
if this.has_head(this.S(0)):
|
||||
this.pop()
|
||||
else:
|
||||
this.unshift()
|
||||
# stack is empty but there is another sentence on the buffer
|
||||
elif this.buffer_length != 0:
|
||||
this.push()
|
||||
else: # stack empty and nothing else coming
|
||||
break
|
||||
|
||||
elif is_space_token(this.B_(0)):
|
||||
# the normal case: we're somewhere inside a sentence
|
||||
if this.stack_depth() > 0:
|
||||
# assert not is_space_token(this.S_(0))
|
||||
# attach all coming space tokens to their last preceding
|
||||
# real token (which should be on the top of the stack)
|
||||
while is_space_token(this.B_(0)):
|
||||
this.add_arc(this.S(0),this.B(0),0)
|
||||
this.push()
|
||||
this.pop()
|
||||
# the rare case: we're at the beginning of a document:
|
||||
# space tokens are attached to the first real token on the buffer
|
||||
elif this.stack_depth() == 0:
|
||||
# store all space tokens on the stack until a real token shows up
|
||||
# or the last token on the buffer is reached
|
||||
while is_space_token(this.B_(0)) and this.buffer_length > 1:
|
||||
this.push()
|
||||
# empty the stack by attaching all space tokens to the
|
||||
# first token on the buffer
|
||||
# boundary case: if all tokens are space tokens, the last one
|
||||
# becomes the head of all others
|
||||
while this.stack_depth() > 0:
|
||||
this.add_arc(this.B(0),this.S(0),0)
|
||||
this.pop()
|
||||
# move the first token onto the stack
|
||||
this.push()
|
||||
|
||||
elif this.stack_depth() == 0:
|
||||
# for one token sentences (?)
|
||||
if this.buffer_length == 1:
|
||||
this.push()
|
||||
this.pop()
|
||||
# with an empty stack and a non-empty buffer
|
||||
# only shift is valid anyway
|
||||
elif this.buffer_length != 0:
|
||||
this.push()
|
||||
|
||||
else: # can this even happen?
|
||||
break
|
||||
|
|
|
@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef weight_t cost = 0
|
||||
cdef int i, B_i
|
||||
for i in range(stcls.c.buffer_length):
|
||||
for i in range(stcls.c.segment_length()):
|
||||
B_i = stcls.B(i)
|
||||
cost += gold.heads[B_i] == target
|
||||
cost += gold.heads[target] == B_i
|
||||
|
@ -74,8 +74,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
|||
break
|
||||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
cost += 1
|
||||
#if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
# cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -117,15 +117,23 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
|||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1
|
||||
if not st.can_push():
|
||||
return False
|
||||
elif st.stack_depth() == 0: # If the stack is empty, we must push
|
||||
return True
|
||||
elif st.shifted[st.B(0)]:
|
||||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
if label != 0:
|
||||
st.split(st.B(1), label)
|
||||
#if label != 0:
|
||||
# st.split(st.B(1), label)
|
||||
st.shifted[st.B(0)] = 1
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||
|
@ -138,7 +146,7 @@ cdef class Shift:
|
|||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
return 0
|
||||
#if gold.fused_tokens[s.B(1)] == label:
|
||||
#if gold.fused_tokens[s.B(1)] == label: TODO
|
||||
# return 0
|
||||
#else:
|
||||
# return 1
|
||||
|
@ -147,15 +155,21 @@ cdef class Shift:
|
|||
cdef class Reduce:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
return st.stack_depth() >= 2
|
||||
if st.stack_depth() >= 2:
|
||||
return True
|
||||
elif st.at_break() and st.stack_depth() == 1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
if st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
elif st.stack_depth() == 1 and st.at_break():
|
||||
st.pop()
|
||||
else:
|
||||
st.unshift()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
|
@ -165,15 +179,15 @@ cdef class Reduce:
|
|||
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
cost = pop_cost(st, gold, st.S(0))
|
||||
if not st.has_head(st.S(0)):
|
||||
# Decrement cost for the arcs e save
|
||||
# Decrement cost for the arcs we save
|
||||
for i in range(1, st.stack_depth()):
|
||||
S_i = st.S(i)
|
||||
if gold.heads[st.S(0)] == S_i:
|
||||
cost -= 1
|
||||
if gold.heads[S_i] == st.S(0):
|
||||
cost -= 1
|
||||
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||
cost -= 1
|
||||
#if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||
# cost -= 1
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
|
@ -184,18 +198,18 @@ cdef class Reduce:
|
|||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return sent_start != 1
|
||||
return st.can_arc()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
|
||||
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
|
||||
return move_cost + label_cost
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
|
@ -220,14 +234,17 @@ cdef class RightArc:
|
|||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
|
||||
if not st.can_arc():
|
||||
return 0
|
||||
elif st.H(st.S(0)) == st.B(0):
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
|
@ -253,21 +270,13 @@ cdef class Break:
|
|||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
elif st.stack_depth() < 1:
|
||||
return False
|
||||
elif st.B_(0).l_edge < 0:
|
||||
return False
|
||||
elif st._sent[st.B_(0).l_edge].sent_start < 0:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return st.can_break()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_break(0)
|
||||
st.fast_forward()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||
|
@ -317,7 +326,6 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
|||
st._sent[i].dep = 0
|
||||
st._sent[i].l_kids = 0
|
||||
st._sent[i].r_kids = 0
|
||||
st.fast_forward()
|
||||
return <void*>st
|
||||
|
||||
|
||||
|
@ -520,7 +528,6 @@ cdef class ArcEager(TransitionSystem):
|
|||
st._sent[i].dep = 0
|
||||
st._sent[i].l_kids = 0
|
||||
st._sent[i].r_kids = 0
|
||||
st.fast_forward()
|
||||
|
||||
cdef int finalize_state(self, StateC* st) nogil:
|
||||
cdef int i
|
||||
|
|
|
@ -137,6 +137,3 @@ cdef class StateClass:
|
|||
|
||||
cdef inline void clone(self, StateClass src) nogil:
|
||||
self.c.clone(src.c)
|
||||
|
||||
cdef inline void fast_forward(self) nogil:
|
||||
self.c.fast_forward()
|
||||
|
|
|
@ -30,28 +30,32 @@ cdef class StateClass:
|
|||
def get_S(self, int i):
|
||||
return self.c.S(i)
|
||||
|
||||
def push_stack(self, fast_forward=True):
|
||||
def can_push(self):
|
||||
return self.c.can_push()
|
||||
|
||||
def can_pop(self):
|
||||
return self.c.can_pop()
|
||||
|
||||
def can_break(self):
|
||||
return self.c.can_break()
|
||||
|
||||
def can_arc(self):
|
||||
return self.c.can_arc()
|
||||
|
||||
def push_stack(self):
|
||||
self.c.push()
|
||||
if fast_forward:
|
||||
self.c.fast_forward()
|
||||
|
||||
def pop_stack(self, fast_forward=True):
|
||||
def pop_stack(self):
|
||||
self.c.pop()
|
||||
if fast_forward:
|
||||
self.c.fast_forward()
|
||||
|
||||
def unshift(self, fast_forward=True):
|
||||
def unshift(self):
|
||||
self.c.unshift()
|
||||
if fast_forward:
|
||||
self.c.fast_forward()
|
||||
|
||||
def set_break(self, int i):
|
||||
self.c.set_break(i)
|
||||
|
||||
def split_token(self, int i, int n, fast_forward=True):
|
||||
def split_token(self, int i, int n):
|
||||
self.c.split(i, n)
|
||||
if fast_forward:
|
||||
self.c.fast_forward()
|
||||
|
||||
def get_doc(self, vocab):
|
||||
cdef Doc doc = Doc(vocab)
|
||||
|
|
|
@ -41,7 +41,6 @@ def test_init_parser(parser):
|
|||
pass
|
||||
|
||||
# TODO: This is flakey, because it depends on what the parser first learns.
|
||||
@pytest.mark.xfail
|
||||
def test_add_label(parser):
|
||||
doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
|
||||
doc = parser(doc)
|
||||
|
|
|
@ -17,7 +17,10 @@ def vocab():
|
|||
|
||||
@pytest.fixture
|
||||
def moves(vocab):
|
||||
aeager = ArcEager(vocab.strings, {})
|
||||
aeager = ArcEager(vocab.strings)
|
||||
aeager.add_action(0, '')
|
||||
aeager.add_action(1, '')
|
||||
aeager.add_action(4, 'ROOT')
|
||||
aeager.add_action(2, 'nsubj')
|
||||
aeager.add_action(3, 'dobj')
|
||||
aeager.add_action(2, 'aux')
|
||||
|
|
|
@ -39,5 +39,5 @@ def test_split():
|
|||
doc = get_doc('abcd')
|
||||
state = StateClass(doc, max_split=3)
|
||||
assert state.queue == [0, 1, 2, 3]
|
||||
state.split_token(1, 2, fast_forward=False)
|
||||
state.split_token(1, 2)
|
||||
assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]
|
||||
|
|
Loading…
Reference in New Issue
Block a user