Merge pull request #346 from wbwseeker/sentbnd_bug

introduce sentence boundaries for additional root tokens
Matthew Honnibal 2016-04-25 20:31:27 +10:00
commit feb65fcaa1
3 changed files with 157 additions and 19 deletions
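Summary of the change: when the Break transition isn't used, the parser segments sentences via root-labelled arcs between root words, but finalize_state() never marked the extra root's left edge as a sentence start, so the boundary was lost (notably across serialization). A minimal usage sketch of the fixed behaviour, mirroring the new tests below and assuming EN is a loaded English model (hypothetical setup, not part of the diff):

# Hypothetical sketch; EN is assumed to be a loaded English model.
# Force a parse in which "glowed" is attached with the ROOT label
# instead of being split off by a Break transition:
doc = EN.tokenizer.tokens_from_list(u"I saw a firefly It glowed".split(' '))
EN.tagger(doc)
with EN.parser.step_through(doc) as stepwise:
    for move in ['L-nsubj', 'S', 'L-det', 'R-dobj', 'D', 'S', 'L-nsubj', 'R-ROOT']:
        stepwise.transition(move)

# finalize_state() now rewires "glowed" into a second root and flags
# its left edge ("It") as a sentence start:
assert len(list(doc.sents)) == 2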

View File

@@ -399,31 +399,34 @@ cdef class ArcEager(TransitionSystem):
         cdef TokenC* orig_head
         cdef int new_edge
         cdef int child_i
-        cdef TokenC* head_i
+        cdef int head_i
         for i in range(st.length):
             if st._sent[i].head == 0 and st._sent[i].dep == 0:
                 st._sent[i].dep = self.root_label
             # If we're not using the Break transition, we segment via root-labelled
             # arcs between the root words.
             elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
-                orig_head_id = st._sent[i].head
+                orig_head_id = i + st._sent[i].head
                 orig_head = &st._sent[orig_head_id]
                 if i < orig_head_id: # i is left dependent
                     orig_head.l_kids -= 1
                     if i == orig_head.l_edge: # i is left-most child
                         # find the second left-most child and make it the new l_edge
                         new_edge = orig_head_id
-                        child_i = i
+                        child_i = i+1
                         while child_i < orig_head_id:
-                            if st._sent[child_i].head == orig_head_id:
+                            if child_i + st._sent[child_i].head == orig_head_id:
                                 new_edge = child_i
+                                break
                             child_i += 1
                         # then walk up the path to root and update the l_edges of all ancestors
                         # the logic here works because the tree is guaranteed to be projective
-                        head_i = &st._sent[orig_head.head]
-                        while head_i.l_edge == orig_head.l_edge:
-                            head_i.l_edge = new_edge
-                            head_i = &st._sent[head_i.head]
+                        head_i = orig_head_id + orig_head.head
+                        while st._sent[head_i].l_edge == orig_head.l_edge:
+                            st._sent[head_i].l_edge = new_edge
+                            if st._sent[head_i].head == 0:
+                                break
+                            head_i += st._sent[head_i].head
                         orig_head.l_edge = new_edge
                 elif i > orig_head_id: # i is right dependent
@@ -431,22 +434,27 @@ cdef class ArcEager(TransitionSystem):
                     if i == orig_head.r_edge:
                         # find the second right-most child and make it the new r_edge
                         new_edge = orig_head_id
-                        child_i = i
+                        child_i = i-1
                         while child_i > orig_head_id:
-                            if st._sent[child_i].head == orig_head_id:
+                            if child_i + st._sent[child_i].head == orig_head_id:
                                 new_edge = child_i
+                                break
                             child_i -= 1
-                        # then walk up the path to root and update the l_edges of all ancestors
+                        # then walk up the path to root and update the r_edges of all ancestors
                         # the logic here works because the tree is guaranteed to be projective
-                        head_i = &st._sent[orig_head.head]
-                        while head_i.r_edge == orig_head.r_edge:
-                            head_i.r_edge = new_edge
-                            head_i = &st._sent[head_i.head]
+                        head_i = orig_head_id + orig_head.head
+                        while st._sent[head_i].r_edge == orig_head.r_edge:
+                            st._sent[head_i].r_edge = new_edge
+                            if st._sent[head_i].head == 0:
+                                break
+                            head_i += st._sent[head_i].head
                         orig_head.r_edge = new_edge
-                # note that this can create non-projective trees if there are arcs
+                # note that this may create non-projective trees if there are arcs
                 # between nodes on both sides of the new root node
                 st._sent[i].head = 0
+                st._sent[st._sent[i].l_edge].sent_start = True
 
     cdef int set_valid(self, int* output, const StateC* st) nogil:
         cdef bint[N_MOVES] is_valid
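The crux of the fix: in spaCy's TokenC array, head is stored as a relative offset from the token, not an absolute index, so the absolute head position is i + st._sent[i].head. The old code indexed the array with the raw offset and so followed the wrong chain of ancestors. A plain-Python sketch of the corrected walk that updates stale l_edge values, with sent as a hypothetical stand-in for the TokenC array:

def update_l_edges(sent, orig_head_id, old_edge, new_edge):
    # sent: list of dicts with a relative 'head' offset and an absolute
    # 'l_edge' index per token (stand-in for the TokenC struct array).
    # Absolute index of the original head's own head:
    head_i = orig_head_id + sent[orig_head_id]['head']
    # Walk up the ancestors that still carry the stale left edge; a root
    # points at itself (offset 0), so stop there to avoid looping forever.
    while sent[head_i]['l_edge'] == old_edge:
        sent[head_i]['l_edge'] = new_edge
        if sent[head_i]['head'] == 0:
            break
        head_i += sent[head_i]['head']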

View File

@@ -19,3 +19,80 @@ def test_one_word_sentence(EN):
     with EN.parser.step_through(doc) as _:
         pass
     assert doc[0].dep != 0
+
+
+def apply_transition_sequence(model, doc, sequence):
+    with model.parser.step_through(doc) as stepwise:
+        for transition in sequence:
+            stepwise.transition(transition)
+
+
+@pytest.mark.models
+def test_arc_eager_finalize_state(EN):
+    # right branching
+    example = EN.tokenizer.tokens_from_list(u"a b c d e".split(' '))
+    apply_transition_sequence(EN, example, ['R-nsubj','D','R-nsubj','R-nsubj','D','R-ROOT'])
+
+    assert example[0].n_lefts == 0
+    assert example[0].n_rights == 2
+    assert example[0].left_edge.i == 0
+    assert example[0].right_edge.i == 3
+    assert example[0].head.i == 0
+
+    assert example[1].n_lefts == 0
+    assert example[1].n_rights == 0
+    assert example[1].left_edge.i == 1
+    assert example[1].right_edge.i == 1
+    assert example[1].head.i == 0
+
+    assert example[2].n_lefts == 0
+    assert example[2].n_rights == 1
+    assert example[2].left_edge.i == 2
+    assert example[2].right_edge.i == 3
+    assert example[2].head.i == 0
+
+    assert example[3].n_lefts == 0
+    assert example[3].n_rights == 0
+    assert example[3].left_edge.i == 3
+    assert example[3].right_edge.i == 3
+    assert example[3].head.i == 2
+
+    assert example[4].n_lefts == 0
+    assert example[4].n_rights == 0
+    assert example[4].left_edge.i == 4
+    assert example[4].right_edge.i == 4
+    assert example[4].head.i == 4
+
+    # left branching
+    example = EN.tokenizer.tokens_from_list(u"a b c d e".split(' '))
+    apply_transition_sequence(EN, example, ['S','L-nsubj','L-ROOT','S','L-nsubj','L-nsubj'])
+
+    assert example[0].n_lefts == 0
+    assert example[0].n_rights == 0
+    assert example[0].left_edge.i == 0
+    assert example[0].right_edge.i == 0
+    assert example[0].head.i == 0
+
+    assert example[1].n_lefts == 0
+    assert example[1].n_rights == 0
+    assert example[1].left_edge.i == 1
+    assert example[1].right_edge.i == 1
+    assert example[1].head.i == 2
+
+    assert example[2].n_lefts == 1
+    assert example[2].n_rights == 0
+    assert example[2].left_edge.i == 1
+    assert example[2].right_edge.i == 2
+    assert example[2].head.i == 4
+
+    assert example[3].n_lefts == 0
+    assert example[3].n_rights == 0
+    assert example[3].left_edge.i == 3
+    assert example[3].right_edge.i == 3
+    assert example[3].head.i == 4
+
+    assert example[4].n_lefts == 2
+    assert example[4].n_rights == 0
+    assert example[4].left_edge.i == 1
+    assert example[4].right_edge.i == 4
+    assert example[4].head.i == 4
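For reference, the left_edge and right_edge checked above are the leftmost and rightmost tokens of a token's subtree, the token itself included. A brute-force reference implementation over absolute head indices (roots point at themselves), valid for the projective trees built here; this is a hypothetical helper, not part of the test suite:

def subtree_edges(heads):
    # heads: absolute head index per token; a root points at itself.
    # Returns [left_edge, right_edge] per token: the leftmost and
    # rightmost positions in that token's subtree, itself included.
    edges = [[i, i] for i in range(len(heads))]
    for i in range(len(heads)):
        # climb from token i to its root, widening each ancestor's span
        j = i
        while heads[j] != j:
            j = heads[j]
            edges[j][0] = min(edges[j][0], i)
            edges[j][1] = max(edges[j][1], i)
    return edges

# the right-branching tree above has heads [0, 0, 0, 2, 4]:
assert subtree_edges([0, 0, 0, 2, 4]) == [[0, 3], [1, 1], [2, 3], [3, 3], [4, 4]]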

View File

@@ -1,7 +1,7 @@
 from __future__ import unicode_literals
 import pytest
+from spacy.tokens import Doc
 
 @pytest.mark.models
@@ -42,7 +42,7 @@ def test_single_question(EN):
 @pytest.mark.models
 def test_sentence_breaks_no_space(EN):
-    doc = EN.tokenizer.tokens_from_list('This is a sentence . This is another one .'.split(' '))
+    doc = EN.tokenizer.tokens_from_list(u'This is a sentence . This is another one .'.split(' '))
     EN.tagger(doc)
     with EN.parser.step_through(doc) as stepwise:
         # stack empty, automatic Shift (This)
@@ -83,7 +83,7 @@ def test_sentence_breaks_no_space(EN):
 @pytest.mark.models
 def test_sentence_breaks_with_space(EN):
-    doc = EN.tokenizer.tokens_from_list('\t This is \n a sentence \n \n . \n \t \n This is another \t one .'.split(' '))
+    doc = EN.tokenizer.tokens_from_list(u'\t This is \n a sentence \n \n . \n \t \n This is another \t one .'.split(' '))
     EN.tagger(doc)
     with EN.parser.step_through(doc) as stepwise:
         # stack empty, automatic Shift (This)
@@ -120,3 +120,56 @@ def test_sentence_breaks_with_space(EN):
     for tok in doc:
         assert tok.dep != 0 or tok.is_space
     assert [ tok.head.i for tok in doc ] == [1,2,2,2,5,2,5,5,2,8,8,8,13,13,16,14,13,13]
+
+
+def apply_transition_sequence(model, doc, sequence):
+    with model.parser.step_through(doc) as stepwise:
+        for transition in sequence:
+            stepwise.transition(transition)
+
+
+@pytest.mark.models
+def test_sbd_for_root_label_dependents(EN):
+    """
+    Make sure that the parser introduces a sentence boundary even without the
+    Break transition, by checking for dependents with the root label.
+    """
+    example = EN.tokenizer.tokens_from_list(u"I saw a firefly It glowed".split(' '))
+    EN.tagger(example)
+    apply_transition_sequence(EN, example, ['L-nsubj','S','L-det','R-dobj','D','S','L-nsubj','R-ROOT'])
+
+    assert example[1].head.i == 1
+    assert example[5].head.i == 5
+
+    sents = list(example.sents)
+    assert len(sents) == 2
+    assert sents[1][0].orth_ == u'It'
+
+
+@pytest.mark.models
+def test_sbd_serialization(EN):
+    """
+    Test that the sentence boundaries are the same before and after
+    serialization, even if the parser predicted two words with the ROOT
+    label that arc_eager.finalize() then split into two sentences.
+
+    This is an interaction between sentence boundary prediction and
+    doc.from_array: if the parser doesn't predict a sentence boundary but
+    attaches a word with the ROOT label, the second root is made the root
+    of its own sentence after parsing. Serialization loses the sentence
+    boundary information, which is reintroduced on deserialization by
+    setting a sentence start at the left edge of every root node.
+
+    The bug tested here: the parser wasn't introducing a sentence start
+    when it introduced the second root node.
+    """
+    example = EN.tokenizer.tokens_from_list(u"I bought a couch from IKEA. It was n't very comfortable .".split(' '))
+    EN.tagger(example)
+    apply_transition_sequence(EN, example, ['L-nsubj','S','L-det','R-dobj','D','R-prep','R-pobj','D','D','S','L-nsubj','R-ROOT','R-neg','D','S','L-advmod','R-acomp','D','R-punct'])
+
+    example_serialized = Doc(EN.vocab).from_bytes(example.to_bytes())
+
+    assert example.to_bytes() == example_serialized.to_bytes()
+    assert [s.text for s in example.sents] == [s.text for s in example_serialized.sents]
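The docstring above describes the rule used to recover boundaries after deserialization: a sentence start is placed at the left edge of every root token. A sketch of that rule over plain arrays (hypothetical helper, not spaCy API; the real logic lives in the deserialization path):

def sentence_starts(heads, l_edges):
    # heads: absolute head index per token (a root points at itself);
    # l_edges: absolute left edge of each token's subtree.
    # A sentence starts at the left edge of every root token, which is
    # why finalize_state() must set sent_start when it creates a second
    # root: otherwise this reconstruction silently merges the sentences.
    starts = [False] * len(heads)
    for i, head in enumerate(heads):
        if head == i:  # root token
            starts[l_edges[i]] = True
    return starts

# "I saw a firefly It glowed": roots at 1 and 5, left edges 0 and 4
assert sentence_starts([1, 1, 3, 1, 5, 5], [0, 0, 2, 2, 4, 4]) == [True, False, False, False, True, False]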