mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Rewrite oracle to not use fast-forward. Seems to work?
This commit is contained in:
parent
c5574f48c7
commit
e887b2330e
|
@ -157,6 +157,45 @@ cdef cppclass StateC:
|
||||||
else:
|
else:
|
||||||
ids[i] = -1
|
ids[i] = -1
|
||||||
|
|
||||||
|
int can_push() nogil const:
|
||||||
|
if this.buffer_length == 0:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
int can_pop() nogil const:
|
||||||
|
if this.stack_depth() < 1:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
int can_arc() nogil const:
|
||||||
|
if this.at_break():
|
||||||
|
return 0
|
||||||
|
elif this.stack_depth() < 1:
|
||||||
|
return 0
|
||||||
|
elif this.buffer_length == 0:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
int can_break() nogil const:
|
||||||
|
if this.buffer_length == 0:
|
||||||
|
return False
|
||||||
|
elif this.B_(0).l_edge < 0:
|
||||||
|
return False
|
||||||
|
elif this._sent[this.B_(0).l_edge].sent_start < 0:
|
||||||
|
return False
|
||||||
|
elif this.stack_depth() < 1: # ?? I guess stops first action break?
|
||||||
|
return False
|
||||||
|
elif this.at_break():
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
int can_split() nogil const:
|
||||||
|
return 0
|
||||||
|
|
||||||
int S(int i) nogil const:
|
int S(int i) nogil const:
|
||||||
if i >= this._s_i:
|
if i >= this._s_i:
|
||||||
return -1
|
return -1
|
||||||
|
@ -265,7 +304,7 @@ cdef cppclass StateC:
|
||||||
return this._n_until_break == 0
|
return this._n_until_break == 0
|
||||||
|
|
||||||
bint is_final() nogil const:
|
bint is_final() nogil const:
|
||||||
return this.stack_depth() <= 0 and this.buffer_length == 0
|
return this.stack_depth() <= 1 and this.buffer_length == 0
|
||||||
|
|
||||||
bint has_head(int i) nogil const:
|
bint has_head(int i) nogil const:
|
||||||
return this.safe_get(i).head != 0
|
return this.safe_get(i).head != 0
|
||||||
|
@ -287,6 +326,12 @@ cdef cppclass StateC:
|
||||||
int stack_depth() nogil const:
|
int stack_depth() nogil const:
|
||||||
return this._s_i
|
return this._s_i
|
||||||
|
|
||||||
|
int segment_length() nogil const:
|
||||||
|
if this._n_until_break != -1:
|
||||||
|
return this._n_until_break
|
||||||
|
else:
|
||||||
|
return this.buffer_length
|
||||||
|
|
||||||
uint64_t hash() nogil const:
|
uint64_t hash() nogil const:
|
||||||
cdef TokenC[11] sig
|
cdef TokenC[11] sig
|
||||||
sig[0] = this.S_(2)[0]
|
sig[0] = this.S_(2)[0]
|
||||||
|
@ -460,69 +505,3 @@ cdef cppclass StateC:
|
||||||
this._n_until_break = src._n_until_break
|
this._n_until_break = src._n_until_break
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
|
|
||||||
void fast_forward() nogil:
|
|
||||||
# space token attachement policy:
|
|
||||||
# - attach space tokens always to the last preceding real token
|
|
||||||
# - except if it's the beginning of a sentence, then attach to the first following
|
|
||||||
# - boundary case: a document containing multiple space tokens but nothing else,
|
|
||||||
# then make the last space token the head of all others
|
|
||||||
|
|
||||||
while is_space_token(this.B_(0)) \
|
|
||||||
or this.eol() \
|
|
||||||
or this.stack_depth() == 0:
|
|
||||||
if this.eol():
|
|
||||||
# remove the last sentence's root from the stack
|
|
||||||
if this.stack_depth() == 1:
|
|
||||||
this.pop()
|
|
||||||
# parser got stuck: reduce stack or unshift
|
|
||||||
elif this.stack_depth() > 1:
|
|
||||||
if this.has_head(this.S(0)):
|
|
||||||
this.pop()
|
|
||||||
else:
|
|
||||||
this.unshift()
|
|
||||||
# stack is empty but there is another sentence on the buffer
|
|
||||||
elif this.buffer_length != 0:
|
|
||||||
this.push()
|
|
||||||
else: # stack empty and nothing else coming
|
|
||||||
break
|
|
||||||
|
|
||||||
elif is_space_token(this.B_(0)):
|
|
||||||
# the normal case: we're somewhere inside a sentence
|
|
||||||
if this.stack_depth() > 0:
|
|
||||||
# assert not is_space_token(this.S_(0))
|
|
||||||
# attach all coming space tokens to their last preceding
|
|
||||||
# real token (which should be on the top of the stack)
|
|
||||||
while is_space_token(this.B_(0)):
|
|
||||||
this.add_arc(this.S(0),this.B(0),0)
|
|
||||||
this.push()
|
|
||||||
this.pop()
|
|
||||||
# the rare case: we're at the beginning of a document:
|
|
||||||
# space tokens are attached to the first real token on the buffer
|
|
||||||
elif this.stack_depth() == 0:
|
|
||||||
# store all space tokens on the stack until a real token shows up
|
|
||||||
# or the last token on the buffer is reached
|
|
||||||
while is_space_token(this.B_(0)) and this.buffer_length > 1:
|
|
||||||
this.push()
|
|
||||||
# empty the stack by attaching all space tokens to the
|
|
||||||
# first token on the buffer
|
|
||||||
# boundary case: if all tokens are space tokens, the last one
|
|
||||||
# becomes the head of all others
|
|
||||||
while this.stack_depth() > 0:
|
|
||||||
this.add_arc(this.B(0),this.S(0),0)
|
|
||||||
this.pop()
|
|
||||||
# move the first token onto the stack
|
|
||||||
this.push()
|
|
||||||
|
|
||||||
elif this.stack_depth() == 0:
|
|
||||||
# for one token sentences (?)
|
|
||||||
if this.buffer_length == 1:
|
|
||||||
this.push()
|
|
||||||
this.pop()
|
|
||||||
# with an empty stack and a non-empty buffer
|
|
||||||
# only shift is valid anyway
|
|
||||||
elif this.buffer_length != 0:
|
|
||||||
this.push()
|
|
||||||
|
|
||||||
else: # can this even happen?
|
|
||||||
break
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
||||||
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
cdef int i, B_i
|
cdef int i, B_i
|
||||||
for i in range(stcls.c.buffer_length):
|
for i in range(stcls.c.segment_length()):
|
||||||
B_i = stcls.B(i)
|
B_i = stcls.B(i)
|
||||||
cost += gold.heads[B_i] == target
|
cost += gold.heads[B_i] == target
|
||||||
cost += gold.heads[target] == B_i
|
cost += gold.heads[target] == B_i
|
||||||
|
@ -74,8 +74,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
||||||
break
|
break
|
||||||
if BINARY_COSTS and cost >= 1:
|
if BINARY_COSTS and cost >= 1:
|
||||||
return cost
|
return cost
|
||||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
#if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||||
cost += 1
|
# cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,15 +117,23 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
if not st.can_push():
|
||||||
return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1
|
return False
|
||||||
|
elif st.stack_depth() == 0: # If the stack is empty, we must push
|
||||||
|
return True
|
||||||
|
elif st.shifted[st.B(0)]:
|
||||||
|
return False
|
||||||
|
elif st.at_break():
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
if label != 0:
|
#if label != 0:
|
||||||
st.split(st.B(1), label)
|
# st.split(st.B(1), label)
|
||||||
|
st.shifted[st.B(0)] = 1
|
||||||
st.push()
|
st.push()
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
@ -138,7 +146,7 @@ cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
return 0
|
return 0
|
||||||
#if gold.fused_tokens[s.B(1)] == label:
|
#if gold.fused_tokens[s.B(1)] == label: TODO
|
||||||
# return 0
|
# return 0
|
||||||
#else:
|
#else:
|
||||||
# return 1
|
# return 1
|
||||||
|
@ -147,15 +155,21 @@ cdef class Shift:
|
||||||
cdef class Reduce:
|
cdef class Reduce:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
return st.stack_depth() >= 2
|
if st.stack_depth() >= 2:
|
||||||
|
return True
|
||||||
|
elif st.at_break() and st.stack_depth() == 1:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
if st.has_head(st.S(0)):
|
if st.has_head(st.S(0)):
|
||||||
st.pop()
|
st.pop()
|
||||||
|
elif st.stack_depth() == 1 and st.at_break():
|
||||||
|
st.pop()
|
||||||
else:
|
else:
|
||||||
st.unshift()
|
st.unshift()
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
@ -165,15 +179,15 @@ cdef class Reduce:
|
||||||
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||||
cost = pop_cost(st, gold, st.S(0))
|
cost = pop_cost(st, gold, st.S(0))
|
||||||
if not st.has_head(st.S(0)):
|
if not st.has_head(st.S(0)):
|
||||||
# Decrement cost for the arcs e save
|
# Decrement cost for the arcs we save
|
||||||
for i in range(1, st.stack_depth()):
|
for i in range(1, st.stack_depth()):
|
||||||
S_i = st.S(i)
|
S_i = st.S(i)
|
||||||
if gold.heads[st.S(0)] == S_i:
|
if gold.heads[st.S(0)] == S_i:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if gold.heads[S_i] == st.S(0):
|
if gold.heads[S_i] == st.S(0):
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
#if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||||
cost -= 1
|
# cost -= 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -184,18 +198,18 @@ cdef class Reduce:
|
||||||
cdef class LeftArc:
|
cdef class LeftArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
return st.can_arc()
|
||||||
return sent_start != 1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
st.add_arc(st.B(0), st.S(0), label)
|
st.add_arc(st.B(0), st.S(0), label)
|
||||||
st.pop()
|
st.pop()
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
cdef weight_t move_cost = LeftArc.move_cost(s, gold)
|
||||||
|
cdef weight_t label_cost = LeftArc.label_cost(s, gold, label)
|
||||||
|
return move_cost + label_cost
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||||
|
@ -220,14 +234,17 @@ cdef class RightArc:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
if not st.can_arc():
|
||||||
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
|
return 0
|
||||||
|
elif st.H(st.S(0)) == st.B(0):
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
st.add_arc(st.S(0), st.B(0), label)
|
st.add_arc(st.S(0), st.B(0), label)
|
||||||
st.push()
|
st.push()
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
@ -253,21 +270,13 @@ cdef class Break:
|
||||||
cdef int i
|
cdef int i
|
||||||
if not USE_BREAK:
|
if not USE_BREAK:
|
||||||
return False
|
return False
|
||||||
elif st.at_break():
|
|
||||||
return False
|
|
||||||
elif st.stack_depth() < 1:
|
|
||||||
return False
|
|
||||||
elif st.B_(0).l_edge < 0:
|
|
||||||
return False
|
|
||||||
elif st._sent[st.B_(0).l_edge].sent_start < 0:
|
|
||||||
return False
|
|
||||||
else:
|
else:
|
||||||
return True
|
return st.can_break()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int transition(StateC* st, attr_t label) nogil:
|
cdef int transition(StateC* st, attr_t label) nogil:
|
||||||
st.set_break(0)
|
st.set_break(0)
|
||||||
st.fast_forward()
|
st.pop()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
|
||||||
|
@ -317,7 +326,6 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||||
st._sent[i].dep = 0
|
st._sent[i].dep = 0
|
||||||
st._sent[i].l_kids = 0
|
st._sent[i].l_kids = 0
|
||||||
st._sent[i].r_kids = 0
|
st._sent[i].r_kids = 0
|
||||||
st.fast_forward()
|
|
||||||
return <void*>st
|
return <void*>st
|
||||||
|
|
||||||
|
|
||||||
|
@ -520,7 +528,6 @@ cdef class ArcEager(TransitionSystem):
|
||||||
st._sent[i].dep = 0
|
st._sent[i].dep = 0
|
||||||
st._sent[i].l_kids = 0
|
st._sent[i].l_kids = 0
|
||||||
st._sent[i].r_kids = 0
|
st._sent[i].r_kids = 0
|
||||||
st.fast_forward()
|
|
||||||
|
|
||||||
cdef int finalize_state(self, StateC* st) nogil:
|
cdef int finalize_state(self, StateC* st) nogil:
|
||||||
cdef int i
|
cdef int i
|
||||||
|
|
|
@ -137,6 +137,3 @@ cdef class StateClass:
|
||||||
|
|
||||||
cdef inline void clone(self, StateClass src) nogil:
|
cdef inline void clone(self, StateClass src) nogil:
|
||||||
self.c.clone(src.c)
|
self.c.clone(src.c)
|
||||||
|
|
||||||
cdef inline void fast_forward(self) nogil:
|
|
||||||
self.c.fast_forward()
|
|
||||||
|
|
|
@ -30,28 +30,32 @@ cdef class StateClass:
|
||||||
def get_S(self, int i):
|
def get_S(self, int i):
|
||||||
return self.c.S(i)
|
return self.c.S(i)
|
||||||
|
|
||||||
def push_stack(self, fast_forward=True):
|
def can_push(self):
|
||||||
|
return self.c.can_push()
|
||||||
|
|
||||||
|
def can_pop(self):
|
||||||
|
return self.c.can_pop()
|
||||||
|
|
||||||
|
def can_break(self):
|
||||||
|
return self.c.can_break()
|
||||||
|
|
||||||
|
def can_arc(self):
|
||||||
|
return self.c.can_arc()
|
||||||
|
|
||||||
|
def push_stack(self):
|
||||||
self.c.push()
|
self.c.push()
|
||||||
if fast_forward:
|
|
||||||
self.c.fast_forward()
|
|
||||||
|
|
||||||
def pop_stack(self, fast_forward=True):
|
def pop_stack(self):
|
||||||
self.c.pop()
|
self.c.pop()
|
||||||
if fast_forward:
|
|
||||||
self.c.fast_forward()
|
|
||||||
|
|
||||||
def unshift(self, fast_forward=True):
|
def unshift(self):
|
||||||
self.c.unshift()
|
self.c.unshift()
|
||||||
if fast_forward:
|
|
||||||
self.c.fast_forward()
|
|
||||||
|
|
||||||
def set_break(self, int i):
|
def set_break(self, int i):
|
||||||
self.c.set_break(i)
|
self.c.set_break(i)
|
||||||
|
|
||||||
def split_token(self, int i, int n, fast_forward=True):
|
def split_token(self, int i, int n):
|
||||||
self.c.split(i, n)
|
self.c.split(i, n)
|
||||||
if fast_forward:
|
|
||||||
self.c.fast_forward()
|
|
||||||
|
|
||||||
def get_doc(self, vocab):
|
def get_doc(self, vocab):
|
||||||
cdef Doc doc = Doc(vocab)
|
cdef Doc doc = Doc(vocab)
|
||||||
|
|
|
@ -41,7 +41,6 @@ def test_init_parser(parser):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO: This is flakey, because it depends on what the parser first learns.
|
# TODO: This is flakey, because it depends on what the parser first learns.
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_add_label(parser):
|
def test_add_label(parser):
|
||||||
doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
|
doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
|
||||||
doc = parser(doc)
|
doc = parser(doc)
|
||||||
|
|
|
@ -17,7 +17,10 @@ def vocab():
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def moves(vocab):
|
def moves(vocab):
|
||||||
aeager = ArcEager(vocab.strings, {})
|
aeager = ArcEager(vocab.strings)
|
||||||
|
aeager.add_action(0, '')
|
||||||
|
aeager.add_action(1, '')
|
||||||
|
aeager.add_action(4, 'ROOT')
|
||||||
aeager.add_action(2, 'nsubj')
|
aeager.add_action(2, 'nsubj')
|
||||||
aeager.add_action(3, 'dobj')
|
aeager.add_action(3, 'dobj')
|
||||||
aeager.add_action(2, 'aux')
|
aeager.add_action(2, 'aux')
|
||||||
|
|
|
@ -39,5 +39,5 @@ def test_split():
|
||||||
doc = get_doc('abcd')
|
doc = get_doc('abcd')
|
||||||
state = StateClass(doc, max_split=3)
|
state = StateClass(doc, max_split=3)
|
||||||
assert state.queue == [0, 1, 2, 3]
|
assert state.queue == [0, 1, 2, 3]
|
||||||
state.split_token(1, 2, fast_forward=False)
|
state.split_token(1, 2)
|
||||||
assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]
|
assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user