WIP on adding split-token actions to parser

This patch starts getting the StateC object ready to split tokens. The
split function works by pushing extra indices into the buffer that point
past the original sentence length, so each new subtoken piece is addressed
by an out-of-range index.
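
A minimal Python sketch (not part of the patch) of that indexing scheme:
piece k of word i in a length-N sentence gets buffer index k * N + i, so
every extra piece is out of range for the original token array. The helper
names below are illustrative only.

    def subtoken_index(i, piece, sent_length):
        # Buffer index for the piece-th part of word i; piece 0 is word i itself.
        return piece * sent_length + i

    def parent_word(index, sent_length):
        # Map any buffer index back to the original word it belongs to.
        return index % sent_length

    # Matches the test at the bottom of this diff: splitting word 1 of a
    # 4-word doc into 3 pieces yields buffer entries 1, 1*4+1 == 5, 2*4+1 == 9.
    assert subtoken_index(1, 1, 4) == 5
    assert subtoken_index(1, 2, 4) == 9
    assert parent_word(9, 4) == 1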

Still todo:

* Update the oracles
* Update GoldParseC
* Interpret the parse once it's complete
* Add retokenizer.split() method
Matthew Honnibal 2018-03-31 20:05:27 +02:00
parent 3e3af01681
commit e5ad35787c
6 changed files with 103 additions and 98 deletions


@@ -13,7 +13,6 @@ from ..symbols cimport punct
from ..attrs cimport IS_SPACE
from ..typedefs cimport attr_t
cdef void _split(StateC* this, int i, int n) nogil
cdef inline bint is_space_token(const TokenC* token) nogil:
return Lexeme.c_check_flag(token.lex, IS_SPACE)
@@ -44,12 +43,14 @@ cdef cppclass StateC:
Entity* _ents
TokenC _empty_token
RingBufferC _hist
int buffer_length
int max_split
int length
int offset
int _s_i
int _b_i
int _e_i
int _break
int _n_until_break
__init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5
@@ -78,7 +79,9 @@ cdef cppclass StateC:
this._stack += PADDING
this.shifted += PADDING
this.length = length
this._break = -1
this.buffer_length = length
this.max_split = 0
this._n_until_break = -1
this._s_i = 0
this._b_i = 0
this._e_i = 0
@@ -160,7 +163,9 @@ cdef cppclass StateC:
return this._stack[this._s_i - (i+1)]
int B(int i) nogil const:
if (i + this._b_i) >= this.length:
if i >= this.buffer_length:
return -1
if this._n_until_break != -1 and i >= this._n_until_break:
return -1
return this._buffer[this._b_i + i]
@@ -254,13 +259,13 @@ cdef cppclass StateC:
return this._s_i <= 0
bint eol() nogil const:
return this.buffer_length() == 0
return this.buffer_length == 0 or this.at_break()
bint at_break() nogil const:
return this._break != -1
return this._n_until_break == 0
bint is_final() nogil const:
return this.stack_depth() <= 0 and this._b_i >= this.length
return this.stack_depth() <= 0 and this.buffer_length == 0
bint has_head(int i) nogil const:
return this.safe_get(i).head != 0
@@ -282,12 +287,6 @@ cdef cppclass StateC:
int stack_depth() nogil const:
return this._s_i
int buffer_length() nogil const:
if this._break != -1:
return this._break - this._b_i
else:
return this.length - this._b_i
uint64_t hash() nogil const:
cdef TokenC[11] sig
sig[0] = this.S_(2)[0]
@@ -311,46 +310,62 @@ cdef cppclass StateC:
return ring_get(&this._hist, i)
void push() nogil:
if this.B(0) != -1:
this._stack[this._s_i] = this.B(0)
if this.buffer_length != 0:
this._stack[this._s_i] = this._buffer[this._b_i]
if this._n_until_break != -1:
this._n_until_break -= 1
this._s_i += 1
this._b_i += 1
this.buffer_length -= 1
if this.B_(0).sent_start == 1:
this.set_break(this.B(0))
if this._b_i > this._break:
this._break = -1
this.set_break(0)
void split(int i, int n) nogil:
'''Split token i of the buffer into n+1 pieces, adding n new tokens.'''
# Let's say we've got a length 10 sentence.
# state.split(5, 2)
# Before: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# After: [0, 1, 2, 3, 4, 5.0, 5.1, 5.2, 6, 7, 8, 9]
# Sentence grows to length 12.
# Words 6-9 move to positions 8-11.
# Words 0-5 stay where they are.
cdef int PADDING = 5
cdef int j
# Unwind the padding, so we can work with the original pointer.
this._sent -= PADDING
this._sent = <TokenC*>realloc(this._sent,
((this.length+n+1) + (PADDING * 2)) * sizeof(TokenC))
for j in range(this.length+PADDING*2, this.length+n+1+PADDING*2):
this._sent[j] = this._empty_token
# Put the start padding back in
this._sent += PADDING
# In our example, we want to move words 6-9 to positions 8-11. So we must
# move a block of 4 words.
cdef int n_moved = this.length - (i+1)
cdef int move_from = i+1
cdef int move_to = i+n+1
memmove(&this._sent[move_to], &this._sent[move_from],
n_moved*sizeof(TokenC))
# Now copy the token that has been split into its neighbours.
for j in range(i+1, i+n+1):
this._sent[j] = this._sent[i]
# Finally, adjust length.
this.length += n
# Let's say we've got a length 10 sentence. 4 is start of buffer.
# We do: state.split(1, 2)
#
# Old buffer: 4,5,6,7,8,9
# New buffer: 4,5,15,25,6,7,8,9
if (this._b_i+5*2) < n:
with gil:
raise NotImplementedError
# Let's say we're at token index 4. this._b_i will be 4, so that we
# point forward into the buffer. To insert, we don't need to reallocate
# -- we have space at the start; we can just shift the tokens between the
# front of the queue and the split point backwards to make room.
#
# For b_i=4, i=1, n=2 we want to have:
# Old buffer: [_, _, _, _, 4, 5, 6, 7, 8, 9] and b_i=4
# New buffer: [_, _, 4, 5, 15, 25, 6, 7, 8, 9] and b_i=2
# b_i will always move back by n in total, as that's
# the size of the gap we're creating.
# The number of values we have to copy will be i+1
# Another way to see it:
# For b_i=4, i=1, n=2
# buffer[2:4] = buffer[4:6]
# buffer[4:6] = new_tokens
# For b_i=7, i=1, n=1
# buffer[6:8] = buffer[7:9]
# buffer[8:9] = new_tokens
# For b_i=3, i=1, n=3
# buffer[0:2] = buffer[3:5]
# buffer[2:5] = new_tokens
# For b_i=5, i=3, n=1
# buffer[4:8] = buffer[5:9]
# buffer[8:9] = new_tokens
cdef int target = this.B(i)
this._b_i -= n
memmove(&this._buffer[this._b_i],
&this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0]))
cdef int subtoken, new_token
for subtoken in range(n):
new_token = (subtoken+1) * this.length + target
this._buffer[this._b_i+(i+1)+subtoken] = new_token
this.buffer_length += n
if this._n_until_break != -1:
this._n_until_break += n
void pop() nogil:
if this._s_i >= 1:
@@ -361,6 +376,9 @@ cdef cppclass StateC:
this._buffer[this._b_i] = this.S(0)
this._s_i -= 1
this.shifted[this.B(0)] = True
this.buffer_length += 1
if this._n_until_break != -1:
this._n_until_break += 1
void add_arc(int head, int child, attr_t label) nogil:
if this.has_head(child):
@@ -424,12 +442,13 @@ cdef cppclass StateC:
this._sent[i].ent_type = ent_type
void set_break(int i) nogil:
if 0 <= i < this.length:
this._sent[i].sent_start = 1
this._break = this._b_i
if 0 <= i < this.buffer_length:
this._sent[this.B_(i).l_edge].sent_start = 1
this._n_until_break = i
void clone(const StateC* src) nogil:
this.length = src.length
this.buffer_length = src.buffer_length
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
memcpy(this._stack, src._stack, this.length * sizeof(int))
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
@@ -438,7 +457,7 @@ cdef cppclass StateC:
this._b_i = src._b_i
this._s_i = src._s_i
this._e_i = src._e_i
this._break = src._break
this._n_until_break = src._n_until_break
this.offset = src.offset
this._empty_token = src._empty_token
@@ -450,9 +469,9 @@ cdef cppclass StateC:
# then make the last space token the head of all others
while is_space_token(this.B_(0)) \
or this.buffer_length() == 0 \
or this.eol() \
or this.stack_depth() == 0:
if this.buffer_length() == 0:
if this.eol():
# remove the last sentence's root from the stack
if this.stack_depth() == 1:
this.pop()
@@ -463,7 +482,7 @@ cdef cppclass StateC:
else:
this.unshift()
# stack is empty but there is another sentence on the buffer
elif (this.length - this._b_i) >= 1:
elif this.buffer_length != 0:
this.push()
else: # stack empty and nothing else coming
break
@@ -483,7 +502,7 @@ cdef cppclass StateC:
elif this.stack_depth() == 0:
# store all space tokens on the stack until a real token shows up
# or the last token on the buffer is reached
while is_space_token(this.B_(0)) and this.buffer_length() > 1:
while is_space_token(this.B_(0)) and this.buffer_length > 1:
this.push()
# empty the stack by attaching all space tokens to the
# first token on the buffer
@@ -497,12 +516,12 @@ cdef cppclass StateC:
elif this.stack_depth() == 0:
# for one token sentences (?)
if this.buffer_length() == 1:
if this.buffer_length == 1:
this.push()
this.pop()
# with an empty stack and a non-empty buffer
# only shift is valid anyway
elif (this.length - this._b_i) >= 1:
elif this.buffer_length != 0:
this.push()
else: # can this even happen?
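
The pointer arithmetic in the new split() above is easier to follow as a
pure-Python model. This is only a sketch of the same in-place shift (list
slices standing in for the memmove); it ignores the padding check and the
_n_until_break bookkeeping, and the helper name is made up.

    def split_buffer(buffer, b_i, i, n, sent_length):
        # Move the cursor back by n to open a gap, shift the first i+1
        # queue items into it, then write the out-of-range subtoken
        # indices directly after the word being split.
        target = buffer[b_i + i]
        b_i -= n
        buffer[b_i:b_i + i + 1] = buffer[b_i + n:b_i + n + i + 1]
        for k in range(n):
            buffer[b_i + i + 1 + k] = (k + 1) * sent_length + target
        return buffer, b_i

    # Mirrors the worked example in the comments: b_i=4, i=1, n=2 on a
    # length-10 sentence.
    buf, b_i = split_buffer(['_', '_', '_', '_', 4, 5, 6, 7, 8, 9], 4, 1, 2, 10)
    assert buf == ['_', '_', 4, 5, 15, 25, 6, 7, 8, 9] and b_i == 2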


@@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, B_i
for i in range(stcls.buffer_length()):
for i in range(stcls.c.buffer_length):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
@@ -118,7 +118,7 @@ cdef class Shift:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
sent_start = st._sent[st.B_(0).l_edge].sent_start
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1
return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@@ -137,10 +137,11 @@ cdef class Shift:
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
if gold.fused_tokens[s.B(1)] == label:
return 0
else:
return 1
#if gold.fused_tokens[s.B(1)] == label:
# return 0
#else:
# return 1
cdef class Reduce:
@@ -265,7 +266,7 @@ cdef class Break:
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.set_break(st.B_(0).l_edge)
st.set_break(0)
st.fast_forward()
@staticmethod
@@ -278,7 +279,7 @@ cdef class Break:
cdef int i, j, S_i, B_i
for i in range(s.stack_depth()):
S_i = s.S(i)
for j in range(s.buffer_length()):
for j in range(s.c.buffer_length):
B_i = s.B(j)
cost += gold.heads[S_i] == B_i
cost += gold.heads[B_i] == S_i


@@ -10,6 +10,7 @@ from ..vocab cimport EMPTY_LEXEME
from ._state cimport StateC
@cython.final
cdef class StateClass:
cdef Pool mem
cdef StateC* c
@@ -105,7 +106,7 @@ cdef class StateClass:
return self.c.stack_depth()
cdef inline int buffer_length(self) nogil:
return self.c.buffer_length()
return self.c.buffer_length
cdef inline void push(self) nogil:
self.c.push()


@@ -8,13 +8,14 @@ from ..tokens.doc cimport Doc
cdef class StateClass:
def __init__(self, Doc doc=None, int offset=0):
def __init__(self, Doc doc=None, int offset=0, int max_split=0):
cdef Pool mem = Pool()
self.mem = mem
self._borrowed = 0
if doc is not None:
self.c = new StateC(doc.c, doc.length)
self.c.offset = offset
self.c.max_split = max_split
def __dealloc__(self):
if self._borrowed != 1:
@@ -39,6 +40,14 @@ cdef class StateClass:
if fast_forward:
self.c.fast_forward()
def unshift(self, fast_forward=True):
self.c.unshift()
if fast_forward:
self.c.fast_forward()
def set_break(self, int i):
self.c.set_break(i)
def split_token(self, int i, int n, fast_forward=True):
self.c.split(i, n)
if fast_forward:
@@ -57,7 +66,7 @@ cdef class StateClass:
@property
def queue(self):
return {self.B(i) for i in range(self.c.buffer_length())}
return [self.B(i) for i in range(self.c.buffer_length)]
@property
def token_vector_lenth(self):


@@ -32,37 +32,12 @@ def test_pop():
assert state.get_S(0) == 0
def toy_split():
def _realloc(data, new_size):
additions = new_size - len(data)
return data + ['']*additions
length = 10
sent = list(range(length))
sent = [None]*pad + sent + [None]*pad # pad
ptr = pad
i = 5
n = 2
ptr -= pad
i += pad
sent = _realloc(sent, length+n+(pad*2))
n_moved = (length + (pad*2)) - i+1
def test_split():
'''state.split_token should take the ith word of the buffer, and split it
into n+1 pieces. n is 0-indexed, i.e. split(i, 0) is a noop, and split(i, 1)
creates 1 new token.'''
doc = get_doc('abcd')
state = StateClass(doc)
assert len(state) == len(doc)
state.split_token(1, 2)
assert len(state) == len(doc)+2
stdoc = state.get_doc(doc.vocab)
assert stdoc[0].text == 'a'
assert stdoc[1].text == 'b'
assert stdoc[2].text == 'b'
assert stdoc[3].text == 'b'
assert stdoc[4].text == 'c'
assert stdoc[5].text == 'd'
state = StateClass(doc, max_split=3)
assert state.queue == [0, 1, 2, 3]
state.split_token(1, 2, fast_forward=False)
assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]
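
As a quick sanity check on the expected queue in that last assertion (just an
illustration, not part of the test): each entry q maps back to an original
word as q % len(doc) and to a piece number as q // len(doc), with piece 0
being the original word itself.

    expected = [0, 1, 1*4+1, 2*4+1, 2, 3]
    assert [(q % 4, q // 4) for q in expected] == \
        [(0, 0), (1, 0), (1, 1), (1, 2), (2, 0), (3, 0)]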