From a7e6c5ac8f20587814784f7f44e3c92d44822dcd Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sun, 18 Oct 2015 17:17:27 +1100
Subject: [PATCH] * Fix Issue #122: Incorrect calculation of children after Doc.merge()

---
Review notes (text in this section is ignored by `git am`):

* Doc.merge(): the trailing-whitespace strip is guarded again. Without the
  guard, a span whose last token has no trailing whitespace gives
  len(whitespace_) == 0, and new_orth[:-0] evaluates to '' -- the merged
  token would be looked up with an empty string.
* tests: added test_merge_end_of_doc as a regression test for that case.
* Hunk headers and the diffstat were recomputed for the extra lines. The
  commit id above and the blob ids on the `index` lines still refer to the
  original revision of this patch.

 spacy/tokens/doc.pyx            | 27 +++++++++++++++++++--------
 tests/tokens/test_tokens_api.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 50b19d4c1..55a83913b 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -447,9 +447,10 @@ cdef class Doc:
 
         cdef Span span = self[start:end]
         # Get LexemeC for newly merged token
-        new_orth = ''.join([t.string for t in span])
-        if span[-1].whitespace_:
-            new_orth = new_orth[:-1]
+        new_orth = ''.join([t.text_with_ws for t in span])
+        # Strip trailing whitespace; guard needed because s[:-0] == ''
+        if span[-1].whitespace_:
+            new_orth = new_orth[:-len(span[-1].whitespace_)]
         cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
         # House the new merged token where it starts
         cdef TokenC* token = &self.data[start]
@@ -508,16 +509,26 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
     cdef TokenC* head
     cdef TokenC* child
     cdef int i
+    # Set number of left/right children to 0. We'll increment it in the loops.
+    for i in range(length):
+        tokens[i].l_kids = 0
+        tokens[i].r_kids = 0
+        tokens[i].l_edge = i
+        tokens[i].r_edge = i
     # Set left edges
     for i in range(length):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child < head and child.l_edge < head.l_edge:
-            head.l_edge = child.l_edge
+        if child < head:
+            if child.l_edge < head.l_edge:
+                head.l_edge = child.l_edge
+            head.l_kids += 1
+
     # Set right edges --- same as above, but iterate in reverse
     for i in range(length-1, -1, -1):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child > head and child.r_edge > head.r_edge:
-            head.r_edge = child.r_edge
-
+        if child > head:
+            if child.r_edge > head.r_edge:
+                head.r_edge = child.r_edge
+            head.r_kids += 1
diff --git a/tests/tokens/test_tokens_api.py b/tests/tokens/test_tokens_api.py
index 675f00235..b40513b02 100644
--- a/tests/tokens/test_tokens_api.py
+++ b/tests/tokens/test_tokens_api.py
@@ -109,3 +109,38 @@ def test_set_ents(EN):
     assert ent.label_ == 'PRODUCT'
     assert ent.start == 2
     assert ent.end == 4
+
+
+def test_merge(EN):
+    doc = EN('WKRO played songs by the beach boys all night')
+
+    assert len(doc) == 9
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 7
+
+    assert doc[4].text == 'the beach boys'
+    assert doc[4].text_with_ws == 'the beach boys '
+    assert doc[4].tag_ == 'NAMED'
+
+
+def test_merge_end_of_doc(EN):
+    """Merge a span that ends the doc (last token has no whitespace)."""
+    doc = EN('WKRO played the beach boys')
+    doc.merge(doc[2].idx, doc[4].idx + len(doc[4]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 3
+    assert doc[2].text == 'the beach boys'
+
+
+@pytest.mark.models
+def test_merge_children(EN):
+    """Test that attachments work correctly after merging."""
+    doc = EN('WKRO played songs by the beach boys all night')
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+
+    for word in doc:
+        if word.i < word.head.i:
+            assert word in list(word.head.lefts)
+        elif word.i > word.head.i:
+            assert word in list(word.head.rights)