mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
replace tests for non-projectivity
- add functions to find non-projective edges - add test file for non-projectivity functions
This commit is contained in:
parent
eae35e9b27
commit
8d531c958b
|
@ -14,6 +14,8 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import nonproj
|
||||||
|
|
||||||
|
|
||||||
def tags_to_entities(tags):
|
def tags_to_entities(tags):
|
||||||
entities = []
|
entities = []
|
||||||
|
@ -236,34 +238,20 @@ cdef class GoldParse:
|
||||||
self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
|
self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
|
||||||
self.labels[i] = annot_tuples[4][gold_i]
|
self.labels[i] = annot_tuples[4][gold_i]
|
||||||
self.ner[i] = annot_tuples[5][gold_i]
|
self.ner[i] = annot_tuples[5][gold_i]
|
||||||
|
|
||||||
# If we have any non-projective arcs, i.e. crossing brackets, consider
|
|
||||||
# the heads for those words missing in the gold-standard.
|
|
||||||
# This way, we can train from these sentences
|
|
||||||
cdef int w1, w2, h1, h2
|
|
||||||
if make_projective:
|
|
||||||
heads = list(self.heads)
|
|
||||||
for w1 in range(self.length):
|
|
||||||
if heads[w1] is not None:
|
|
||||||
h1 = heads[w1]
|
|
||||||
for w2 in range(w1+1, self.length):
|
|
||||||
if heads[w2] is not None:
|
|
||||||
h2 = heads[w2]
|
|
||||||
if _arcs_cross(w1, h1, w2, h2):
|
|
||||||
self.heads[w1] = None
|
|
||||||
self.labels[w1] = ''
|
|
||||||
self.heads[w2] = None
|
|
||||||
self.labels[w2] = ''
|
|
||||||
|
|
||||||
# Check there are no cycles in the dependencies, i.e. we are a tree
|
cycle = nonproj.contains_cycle(self.heads)
|
||||||
for w in range(self.length):
|
if cycle != None:
|
||||||
seen = set([w])
|
raise Exception("Cycle found: %s" % cycle)
|
||||||
head = w
|
|
||||||
while self.heads[head] != head and self.heads[head] != None:
|
if make_projective:
|
||||||
head = self.heads[head]
|
# projectivity here means non-proj arcs are being disconnected
|
||||||
if head in seen:
|
np_arcs = []
|
||||||
raise Exception("Cycle found: %s" % seen)
|
for word in range(self.length):
|
||||||
seen.add(head)
|
if nonproj.is_non_projective_arc(word,self.heads):
|
||||||
|
np_arcs.append(word)
|
||||||
|
for np_arc in np_arcs:
|
||||||
|
self.heads[np_arc] = None
|
||||||
|
self.labels[np_arc] = ''
|
||||||
|
|
||||||
self.brackets = {}
|
self.brackets = {}
|
||||||
for (gold_start, gold_end, label_str) in brackets:
|
for (gold_start, gold_end, label_str) in brackets:
|
||||||
|
@ -278,25 +266,18 @@ cdef class GoldParse:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_projective(self):
|
def is_projective(self):
|
||||||
heads = list(self.heads)
|
return not nonproj.is_non_projective_tree(self.heads)
|
||||||
for w1 in range(self.length):
|
|
||||||
if heads[w1] is not None:
|
|
||||||
h1 = heads[w1]
|
|
||||||
for w2 in range(self.length):
|
|
||||||
if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1:
|
|
||||||
if w1 > h1:
|
|
||||||
w1, h1 = h1, w1
|
|
||||||
if w2 > h2:
|
|
||||||
w2, h2 = h2, w2
|
|
||||||
if w1 > w2:
|
|
||||||
w1, h1, w2, h2 = w2, h2, w1, h1
|
|
||||||
return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1
|
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
return label == 'P' or label.lower() == 'punct'
|
return label == 'P' or label.lower() == 'punct'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
55
spacy/nonproj.py
Normal file
55
spacy/nonproj.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
|
||||||
|
|
||||||
|
def ancestors(word, heads):
    """Yield the chain of heads reached by walking upwards from `word`.

    `heads[i]` is the index of the head of word `i`; a word that is its own
    head is treated as the root and ends the walk. `word` itself is not
    yielded.

    The walk is capped at `len(heads)` steps: a path to the root can never be
    longer than the sentence, so the cap only takes effect on cyclic input,
    where the walk would otherwise never terminate. On a cycle the same
    indices are therefore yielded repeatedly until the cap is hit.

    If a head is `None` (unattached token), `None` is yielded once and the
    walk stops.
    """
    head = word
    steps = 0
    max_steps = len(heads)
    while heads[head] != head and steps < max_steps:
        head = heads[head]
        steps += 1
        yield head
        # None cannot be used as an index, so the walk must end here.
        if head is None:
            break
|
||||||
|
|
||||||
|
|
||||||
|
def contains_cycle(heads):
    """Return the set of visited words if `heads` contains a cycle, else None.

    In an acyclic tree, following the head relation upwards from any word
    always ends at the root. For each start word the path is followed via
    `ancestors`; as soon as a word is reached a second time, the set of words
    visited on that path (including the start word) is returned.
    """
    for start in range(len(heads)):
        visited = {start}
        for node in ancestors(start, heads):
            if node in visited:
                return visited
            visited.add(node)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_projective_arc(word, heads):
    """Return True if the arc from `heads[word]` to `word` is non-projective.

    Definition (e.g. Havelka 2007): an arc h -> d with h < d is
    non-projective if there is a word k, h < k < d, such that h is not an
    ancestor of k. The same holds for h -> d with h > d.

    Root arcs (a word that is its own head) and unattached tokens (head is
    None) are never considered non-projective.
    """
    head = heads[word]
    if head is None or head == word:
        return False

    # Range of the words lying strictly between the two ends of the arc.
    start, end = (head + 1, word) if head < word else (word + 1, head)
    for k in range(start, end):
        for ancestor in ancestors(k, heads):
            # None: k belongs to an unattached token/subtree, which is
            #       not counted against the arc.
            # head: the normal case, k is dominated by the arc's head.
            if ancestor is None or ancestor == head:
                break
        else:
            # The walk finished without meeting `head` (or None), so `head`
            # is not an ancestor of k and the arc is non-projective.
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_projective_tree(heads):
    """Return True if at least one arc of the tree `heads` is non-projective."""
    for word in range(len(heads)):
        if is_non_projective_arc(word, heads):
            return True
    return False
|
||||||
|
|
|
@ -211,6 +211,11 @@ cdef class Tagger:
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
tokens._py_tokens = [None] * tokens.length
|
tokens._py_tokens = [None] * tokens.length
|
||||||
|
|
||||||
|
def tags_from_list(self, Doc tokens, list strings):
    """Assign the tags in `strings` to the tokens of `tokens`, position by position.

    `strings` must contain exactly one tag per token; each tag is applied
    through `self.vocab.morphology.assign_tag`.

    NOTE(review): this method does not set `tokens.is_tagged`, unlike the
    code shown just before it in this class -- confirm that is intended.
    """
    assert tokens.length == len(strings), \
        "Expected %d tags, got %d" % (tokens.length, len(strings))
    for i in range(tokens.length):
        self.vocab.morphology.assign_tag(&tokens.c[i], strings[i])
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=1000, n_threads=2):
|
def pipe(self, stream, batch_size=1000, n_threads=2):
|
||||||
for doc in stream:
|
for doc in stream:
|
||||||
self(doc)
|
self(doc)
|
||||||
|
|
42
spacy/tests/test_nonproj.py
Normal file
42
spacy/tests/test_nonproj.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from spacy.nonproj import ancestors, contains_cycle, is_non_projective_arc, is_non_projective_tree
|
||||||
|
|
||||||
|
def test_ancestors():
    """ancestors() walks to the root, stops at None, and is bounded on cycles."""
    tree = [1, 2, 2, 4, 5, 2, 2]
    cyclic_tree = [1, 2, 2, 4, 5, 3, 2]
    partial_tree = [1, 2, 2, 4, 5, None, 2]
    assert list(ancestors(3, tree)) == [4, 5, 2]
    # On a cycle the walk is cut off after len(heads) == 7 steps.
    assert list(ancestors(3, cyclic_tree)) == [4, 5, 3, 4, 5, 3, 4]
    assert list(ancestors(3, partial_tree)) == [4, 5, None]
|
||||||
|
|
||||||
|
def test_contains_cycle():
    """contains_cycle() returns None for acyclic input and the visited set otherwise."""
    tree = [1, 2, 2, 4, 5, 2, 2]
    cyclic_tree = [1, 2, 2, 4, 5, 3, 2]
    partial_tree = [1, 2, 2, 4, 5, None, 2]
    assert contains_cycle(tree) is None
    assert contains_cycle(cyclic_tree) == {3, 4, 5}
    assert contains_cycle(partial_tree) is None
|
||||||
|
|
||||||
|
def test_is_non_projective_arc():
    """Only the arc of word 7 is non-projective; detaching word 5 removes it."""
    nonproj_tree = [1, 2, 2, 4, 5, 2, 7, 4, 2]
    # Expected result for each word index 0..8 of nonproj_tree.
    expected = [False, False, False, False, False, False, False, True, False]
    for word, is_nonproj in enumerate(expected):
        assert is_non_projective_arc(word, nonproj_tree) == is_nonproj
    partial_tree = [1, 2, 2, 4, 5, None, 7, 4, 2]
    assert is_non_projective_arc(7, partial_tree) == False
|
||||||
|
|
||||||
|
def test_is_non_projective_tree():
    """A tree is non-projective exactly when one of its arcs is."""
    cases = [
        ([1, 2, 2, 4, 5, 2, 7, 5, 2], False),     # projective tree
        ([1, 2, 2, 4, 5, 2, 7, 4, 2], True),      # non-projective tree
        ([1, 2, 2, 4, 5, None, 7, 4, 2], False),  # partial tree
    ]
    for heads, is_nonproj in cases:
        assert is_non_projective_tree(heads) == is_nonproj
|
Loading…
Reference in New Issue
Block a user