diff --git a/spacy/syntax/nonproj.pyx b/spacy/syntax/nonproj.pyx index fb04ecb2d..9339efb39 100644 --- a/spacy/syntax/nonproj.pyx +++ b/spacy/syntax/nonproj.pyx @@ -118,15 +118,18 @@ class PseudoProjectivity: # reattach arcs with decorated labels (following HEAD scheme) # for each decorated arc X||Y, search top-down, left-to-right, # breadth-first until hitting a Y then make this the new head - parse = tokens.to_array([HEAD, DEP]) - labels = [ tokens.vocab.strings[int(p[1])] for p in parse ] + #parse = tokens.to_array([HEAD, DEP]) for token in tokens: if cls.is_decorated(token.dep_): newlabel,headlabel = cls.decompose(token.dep_) newhead = cls._find_new_head(token,headlabel) - parse[token.i,1] = tokens.vocab.strings[newlabel] - parse[token.i,0] = newhead.i - token.i - tokens.from_array([HEAD, DEP],parse) + token.head = newhead + token.dep_ = newlabel + + # tokens.attach(token,newhead,newlabel) + #parse[token.i,1] = tokens.vocab.strings[newlabel] + #parse[token.i,0] = newhead.i - token.i + #tokens.from_array([HEAD, DEP],parse) @classmethod @@ -168,7 +171,7 @@ class PseudoProjectivity: @classmethod def _find_new_head(cls, token, headlabel): - # search through the tree starting from root + # search through the tree starting from the head of the given token # returns the id of the first descendant with the given label # if there is none, return the current head (no change) queue = [token.head] @@ -176,8 +179,8 @@ class PseudoProjectivity: next_queue = [] for qtoken in queue: for child in qtoken.children: - if child == token: - continue + if child.is_space: continue + if child == token: continue if child.dep_ == headlabel: return child next_queue.append(child) diff --git a/spacy/tests/tokens/test_token_api.py b/spacy/tests/tokens/test_token_api.py index 6deaadfbf..fba8a4d67 100644 --- a/spacy/tests/tokens/test_token_api.py +++ b/spacy/tests/tokens/test_token_api.py @@ -62,3 +62,67 @@ def test_vectors(EN): assert sum(apples.vector) != sum(oranges.vector) assert apples.vector_norm != oranges.vector_norm +@pytest.mark.models +def test_ancestors(EN): + # the structure of this sentence depends on the English annotation scheme + tokens = EN(u'Yesterday I saw a dog that barked loudly.') + ancestors = [ t.orth_ for t in tokens[6].ancestors ] + assert ancestors == ['dog','saw'] + ancestors = [ t.orth_ for t in tokens[1].ancestors ] + assert ancestors == ['saw'] + ancestors = [ t.orth_ for t in tokens[2].ancestors ] + assert ancestors == [] + + assert tokens[2].is_ancestor_of(tokens[7]) + assert not tokens[6].is_ancestor_of(tokens[2]) + + +@pytest.mark.models +def test_head_setter(EN): + # the structure of this sentence depends on the English annotation scheme + yesterday, i, saw, a, dog, that, barked, loudly, dot = EN(u'Yesterday I saw a dog that barked loudly.') + assert barked.n_lefts == 1 + assert barked.n_rights == 1 + assert barked.left_edge == that + assert barked.right_edge == loudly + + assert dog.n_lefts == 1 + assert dog.n_rights == 1 + assert dog.left_edge == a + assert dog.right_edge == loudly + + assert a.n_lefts == 0 + assert a.n_rights == 0 + assert a.left_edge == a + assert a.right_edge == a + + assert saw.left_edge == yesterday + assert saw.right_edge == dot + + barked.head = a + + assert barked.n_lefts == 1 + assert barked.n_rights == 1 + assert barked.left_edge == that + assert barked.right_edge == loudly + + assert a.n_lefts == 0 + assert a.n_rights == 1 + assert a.left_edge == a + assert a.right_edge == loudly + + assert dog.n_lefts == 1 + assert dog.n_rights == 0 + assert dog.left_edge == a + assert dog.right_edge == loudly + + assert saw.left_edge == yesterday + assert saw.right_edge == dot + + yesterday.head = that + + assert that.left_edge == yesterday + assert barked.left_edge == yesterday + assert a.left_edge == yesterday + assert dog.left_edge == yesterday + assert saw.left_edge == yesterday diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd index 8ff0e9587..1706cdc55 100644 --- a/spacy/tokens/token.pxd +++ b/spacy/tokens/token.pxd @@ -6,7 +6,7 @@ from .doc cimport Doc cdef class Token: cdef Vocab vocab - cdef const TokenC* c + cdef TokenC* c cdef readonly int i cdef int array_len cdef readonly Doc doc diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index 0ff574f1b..8b920934c 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -142,6 +142,8 @@ cdef class Token: property dep: def __get__(self): return self.c.dep + def __set__(self, int label): + self.c.dep = label property has_vector: def __get__(self): @@ -250,10 +252,113 @@ cdef class Token: def __get__(self): return self.doc[self.c.r_edge] + property ancestors: + def __get__(self): + cdef const TokenC* head_ptr = self.c + # guard against infinite loop, no token can have + # more ancestors than tokens in the tree + cdef int i = 0 + while head_ptr.head != 0 and i < self.doc.length: + head_ptr += head_ptr.head + yield self.doc[head_ptr - (self.c - self.i)] + i += 1 + + def is_ancestor_of(self, descendant): + return any( ancestor.i == self.i for ancestor in descendant.ancestors ) + property head: def __get__(self): """The token predicted by the parser to be the head of the current token.""" return self.doc[self.i + self.c.head] + def __set__(self, Token new_head): + # this function sets the head of self to new_head + # and updates the counters for left/right dependents + # and left/right corner for the new and the old head + + # do nothing if old head is new head + if self.i + self.c.head == new_head.i: + return + + cdef Token old_head = self.head + cdef int rel_newhead_i = new_head.i - self.i + + # is the new head a descendant of the old head + cdef bint is_desc = old_head.is_ancestor_of(new_head) + + cdef int new_edge + cdef Token anc, child + + # update number of deps of old head + if self.c.head > 0: # left dependent + old_head.c.l_kids -= 1 + if self.c.l_edge == old_head.c.l_edge: + # the token dominates the left edge so the left edge of the head + # may change when the token is reattached + # it may not change if the new head is a descendant of the current head + + new_edge = self.c.l_edge + # the new l_edge is the left-most l_edge on any of the other dependents + # where the l_edge is left of the head, otherwise it is the head + if not is_desc: + new_edge = old_head.i + for child in old_head.children: + if child == self: + continue + if child.c.l_edge < new_edge: + new_edge = child.c.l_edge + old_head.c.l_edge = new_edge + + # walk up the tree from old_head and assign new l_edge to ancestors + # until an ancestor already has an l_edge that's further left + for anc in old_head.ancestors: + if anc.c.l_edge <= new_edge: + break + anc.c.l_edge = new_edge + + elif self.c.head < 0: # right dependent + old_head.c.r_kids -= 1 + # do the same thing as for l_edge + if self.c.r_edge == old_head.c.r_edge: + new_edge = self.c.r_edge + + if not is_desc: + new_edge = old_head.i + for child in old_head.children: + if child == self: + continue + if child.c.r_edge > new_edge: + new_edge = child.c.r_edge + old_head.c.r_edge = new_edge + + for anc in old_head.ancestors: + if anc.c.r_edge >= new_edge: + break + anc.c.r_edge = new_edge + + # update number of deps of new head + if rel_newhead_i > 0: # left dependent + new_head.c.l_kids += 1 + # walk up the tree from new head and set l_edge to self.l_edge + # until you hit a token with an l_edge further to the left + if self.c.l_edge < new_head.c.l_edge: + new_head.c.l_edge = self.c.l_edge + for anc in new_head.ancestors: + if anc.c.l_edge <= self.c.l_edge: + break + anc.c.l_edge = self.c.l_edge + + elif rel_newhead_i < 0: # right dependent + new_head.c.r_kids += 1 + # do the same as for l_edge + if self.c.r_edge > new_head.c.r_edge: + new_head.c.r_edge = self.c.r_edge + for anc in new_head.ancestors: + if anc.c.r_edge >= self.c.r_edge: + break + anc.c.r_edge = self.c.r_edge + + # set new head + self.c.head = rel_newhead_i property conjuncts: def __get__(self): @@ -325,6 +430,8 @@ cdef class Token: property dep_: def __get__(self): return self.vocab.strings[self.c.dep] + def __set__(self, unicode label): + self.c.dep = self.vocab.strings[label] property is_oov: def __get__(self): return Lexeme.c_check_flag(self.c.lex, IS_OOV)