From 8d531c958b8ac6d58e527162c10bca8e5885d916 Mon Sep 17 00:00:00 2001 From: Wolfgang Seeker Date: Mon, 22 Feb 2016 14:40:40 +0100 Subject: [PATCH] replace tests for non-projectivity - add functions to find non-projective edges - add test file for non-projectivity functions --- spacy/gold.pyx | 71 ++++++++++++++----------------------- spacy/nonproj.py | 55 ++++++++++++++++++++++++++++ spacy/tagger.pyx | 5 +++ spacy/tests/test_nonproj.py | 42 ++++++++++++++++++++++ 4 files changed, 128 insertions(+), 45 deletions(-) create mode 100644 spacy/nonproj.py create mode 100644 spacy/tests/test_nonproj.py diff --git a/spacy/gold.pyx b/spacy/gold.pyx index d8b100744..dd29a42c7 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -14,6 +14,8 @@ try: except ImportError: import json +import nonproj + def tags_to_entities(tags): entities = [] @@ -236,34 +238,20 @@ cdef class GoldParse: self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]] self.labels[i] = annot_tuples[4][gold_i] self.ner[i] = annot_tuples[5][gold_i] - - # If we have any non-projective arcs, i.e. crossing brackets, consider - # the heads for those words missing in the gold-standard. - # This way, we can train from these sentences - cdef int w1, w2, h1, h2 - if make_projective: - heads = list(self.heads) - for w1 in range(self.length): - if heads[w1] is not None: - h1 = heads[w1] - for w2 in range(w1+1, self.length): - if heads[w2] is not None: - h2 = heads[w2] - if _arcs_cross(w1, h1, w2, h2): - self.heads[w1] = None - self.labels[w1] = '' - self.heads[w2] = None - self.labels[w2] = '' - # Check there are no cycles in the dependencies, i.e. 
we are a tree - for w in range(self.length): - seen = set([w]) - head = w - while self.heads[head] != head and self.heads[head] != None: - head = self.heads[head] - if head in seen: - raise Exception("Cycle found: %s" % seen) - seen.add(head) + cycle = nonproj.contains_cycle(self.heads) + if cycle != None: + raise Exception("Cycle found: %s" % cycle) + + if make_projective: + # projectivity here means non-proj arcs are being disconnected + np_arcs = [] + for word in range(self.length): + if nonproj.is_non_projective_arc(word,self.heads): + np_arcs.append(word) + for np_arc in np_arcs: + self.heads[np_arc] = None + self.labels[np_arc] = '' self.brackets = {} for (gold_start, gold_end, label_str) in brackets: @@ -278,25 +266,18 @@ cdef class GoldParse: @property def is_projective(self): - heads = list(self.heads) - for w1 in range(self.length): - if heads[w1] is not None: - h1 = heads[w1] - for w2 in range(self.length): - if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]): - return False - return True - - -cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1: - if w1 > h1: - w1, h1 = h1, w1 - if w2 > h2: - w2, h2 = h2, w2 - if w1 > w2: - w1, h1, w2, h2 = w2, h2, w1, h1 - return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1 + return not nonproj.is_non_projective_tree(self.heads) def is_punct_label(label): return label == 'P' or label.lower() == 'punct' + + + + + + + + + + diff --git a/spacy/nonproj.py b/spacy/nonproj.py new file mode 100644 index 000000000..58f9f3e9b --- /dev/null +++ b/spacy/nonproj.py @@ -0,0 +1,55 @@ + + +def ancestors(word, heads): + # returns all words going from the word up the path to the root + # the path to root cannot be longer than the number of words in the sentence + # this function ends after at most len(heads) steps + # because it would otherwise loop indefinitely on cycles + head = word + cnt = 0 + while heads[head] != head and cnt < len(heads): + head = heads[head] + cnt += 1 + yield head + if head == None: + break + 
+
+def contains_cycle(heads):
+    # in an acyclic tree, following the head relation upwards from any word ends at
+    # the root; returns the set of words seen on the first looping path, else None
+    for word in range(len(heads)):
+        seen = {word}
+        for ancestor in ancestors(word, heads):
+            if ancestor in seen:
+                return seen
+            seen.add(ancestor)
+    return None
+
+
+def is_non_projective_arc(word, heads):
+    # definition (e.g. Havelka 2007): an arc h -> d, h < d, is non-projective
+    # if there is a word k, h < k < d, such that h is not an ancestor of k;
+    # the same holds for arcs with h > d. Here d = word and h = heads[word].
+    head = heads[word]
+    if head == word:  # root arcs cannot be non-projective
+        return False
+    elif head is None:  # unattached tokens cannot be non-projective
+        return False
+
+    start, end = (head + 1, word) if head < word else (word + 1, head)
+    for k in range(start, end):
+        for ancestor in ancestors(k, heads):
+            if ancestor is None:  # k sits in an unattached subtree: ignore it
+                break
+            elif ancestor == head:  # normal case: k dominated by h
+                break
+        else:  # head not among k's ancestors: h -> d is non-projective
+            return True
+    return False
+
+
+def is_non_projective_tree(heads):
+    # a tree is non-projective if at least one arc is non-projective
+    return any(is_non_projective_arc(word, heads) for word in range(len(heads)))
+
diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx
index 26f8fd3e5..1c5baced7 100644
--- a/spacy/tagger.pyx
+++ b/spacy/tagger.pyx
@@ -211,6 +211,11 @@ cdef class Tagger:
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
 
+    def tags_from_list(self, Doc tokens, list strings):
+        assert(tokens.length == len(strings))
+        for i in range(tokens.length):
+            self.vocab.morphology.assign_tag(&tokens.c[i], strings[i])
+
     def pipe(self, stream, batch_size=1000, n_threads=2):
         for doc in stream:
             self(doc)
diff --git a/spacy/tests/test_nonproj.py b/spacy/tests/test_nonproj.py
new file mode 100644
index 000000000..bd7f12bff
--- /dev/null
+++ b/spacy/tests/test_nonproj.py
@@ -0,0 +1,42 @@
+from __future__ import unicode_literals
+import pytest
+
+from spacy.nonproj import ancestors, contains_cycle, is_non_projective_arc, is_non_projective_tree
+
+def test_ancestors():
+    tree = [1, 2, 2, 4, 5, 2, 2]  # word 2 is the root (its own head)
+    cyclic_tree = [1, 2, 2, 4, 5, 3, 2]  # 3 -> 4 -> 5 -> 3 form a cycle
+    partial_tree = [1, 2, 2, 4, 5, None, 2]  # word 5 is unattached
+    assert list(ancestors(3, tree)) == [4, 5, 2]
+    assert list(ancestors(3, cyclic_tree)) == [4, 5, 3, 4, 5, 3, 4]  # capped at len(heads)
+    assert list(ancestors(3, partial_tree)) == [4, 5, None]
+
+def test_contains_cycle():
+    tree = [1, 2, 2, 4, 5, 2, 2]
+    cyclic_tree = [1, 2, 2, 4, 5, 3, 2]  # 3 -> 4 -> 5 -> 3 form a cycle
+    partial_tree = [1, 2, 2, 4, 5, None, 2]  # word 5 is unattached
+    assert contains_cycle(tree) is None
+    assert contains_cycle(cyclic_tree) == {3, 4, 5}
+    assert contains_cycle(partial_tree) is None
+
+def test_is_non_projective_arc():
+    nonproj_tree = [1, 2, 2, 4, 5, 2, 7, 4, 2]  # arc 4 -> 7 spans word 5, headed by 2
+    assert is_non_projective_arc(0, nonproj_tree) is False
+    assert is_non_projective_arc(1, nonproj_tree) is False
+    assert is_non_projective_arc(2, nonproj_tree) is False
+    assert is_non_projective_arc(3, nonproj_tree) is False
+    assert is_non_projective_arc(4, nonproj_tree) is False
+    assert is_non_projective_arc(5, nonproj_tree) is False
+    assert is_non_projective_arc(6, nonproj_tree) is False
+    assert is_non_projective_arc(7, nonproj_tree) is True
+    assert is_non_projective_arc(8, nonproj_tree) is False
+    partial_tree = [1, 2, 2, 4, 5, None, 7, 4, 2]  # word 5 is unattached
+    assert is_non_projective_arc(7, partial_tree) is False
+
+def test_is_non_projective_tree():
+    proj_tree = [1, 2, 2, 4, 5, 2, 7, 5, 2]
+    nonproj_tree = [1, 2, 2, 4, 5, 2, 7, 4, 2]  # arc 4 -> 7 is non-projective
+    partial_tree = [1, 2, 2, 4, 5, None, 7, 4, 2]  # word 5 is unattached
+    assert is_non_projective_tree(proj_tree) is False
+    assert is_non_projective_tree(nonproj_tree) is True
+    assert is_non_projective_tree(partial_tree) is False