Detect cycle during projectivize (#10877)

* detect cycle during projectivize

* not complete test to detect cycle in projectivize

* boolean to int type to propagate error

* use unordered_set instead of set

* moved error message to errors

* removed cycle from test case

* use find instead of count

* cycle check: only perform one lookup

* Return bool again from _has_head_as_ancestor

Communicate presence of cycles through an output argument.

* Switch to returning std::pair to encode presence of a cycle

The has_cycle pointer is too easy to misuse. Ideally, we would have a
sum type like Rust's `Result` here, but C++ is not there yet.

* _is_non_proj_arc: clarify what we are returning

* _has_head_as_ancestor: remove count

We are now explicitly checking for cycles, so the algorithm must always
terminate. Either we encounter the head, we find a root, or a cycle.

* _is_nonproj_arc: simplify condition

* Another refactor using C++ exceptions

* Remove unused error code

* Print graph with cycle on exception

* Include .hh files in source package

* Add FIXME comment

* cycle detection test

* find cycle when starting from problematic vertex

Co-authored-by: Daniël de Kok <me@danieldk.eu>
This commit is contained in:
kadarakos 2022-06-08 19:34:11 +02:00 committed by GitHub
parent d176afd32f
commit 1bb87f35bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 16 deletions

View File

@ -1,4 +1,4 @@
recursive-include spacy *.pyi *.pyx *.pxd *.txt *.cfg *.jinja *.toml recursive-include spacy *.pyi *.pyx *.pxd *.txt *.cfg *.jinja *.toml *.hh
include LICENSE include LICENSE
include README.md include README.md
include pyproject.toml include pyproject.toml

View File

@ -0,0 +1,11 @@
#ifndef NONPROJ_HH
#define NONPROJ_HH
#include <stdexcept>
#include <string>
void raise_domain_error(std::string const &msg) {
throw std::domain_error(msg);
}
#endif // NONPROJ_HH

View File

@ -0,0 +1,4 @@
from libcpp.string cimport string
cdef extern from "nonproj.hh":
cdef void raise_domain_error(const string& msg) nogil except +

View File

@ -4,10 +4,13 @@ for doing pseudo-projective parsing implementation uses the HEAD decoration
scheme. scheme.
""" """
from copy import copy from copy import copy
from cython.operator cimport preincrement as incr, dereference as deref
from libc.limits cimport INT_MAX from libc.limits cimport INT_MAX
from libc.stdlib cimport abs from libc.stdlib cimport abs
from libcpp cimport bool from libcpp cimport bool
from libcpp.string cimport string, to_string
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp.unordered_set cimport unordered_set
from ...tokens.doc cimport Doc, set_children_from_heads from ...tokens.doc cimport Doc, set_children_from_heads
@ -49,7 +52,7 @@ def is_nonproj_arc(tokenid, heads):
return _is_nonproj_arc(tokenid, c_heads) return _is_nonproj_arc(tokenid, c_heads)
cdef bool _is_nonproj_arc(int tokenid, const vector[int]& heads) nogil: cdef bool _is_nonproj_arc(int tokenid, const vector[int]& heads) nogil except *:
# definition (e.g. Havelka 2007): an arc h -> d, h < d is non-projective # definition (e.g. Havelka 2007): an arc h -> d, h < d is non-projective
# if there is a token k, h < k < d such that h is not # if there is a token k, h < k < d such that h is not
# an ancestor of k. Same for h -> d, h > d # an ancestor of k. Same for h -> d, h > d
@ -58,32 +61,56 @@ cdef bool _is_nonproj_arc(int tokenid, const vector[int]& heads) nogil:
return False return False
elif head < 0: # unattached tokens cannot be non-projective elif head < 0: # unattached tokens cannot be non-projective
return False return False
cdef int start, end cdef int start, end
if head < tokenid: if head < tokenid:
start, end = (head+1, tokenid) start, end = (head+1, tokenid)
else: else:
start, end = (tokenid+1, head) start, end = (tokenid+1, head)
for k in range(start, end): for k in range(start, end):
if _has_head_as_ancestor(k, head, heads): if not _has_head_as_ancestor(k, head, heads):
continue
else: # head not in ancestors: d -> h is non-projective
return True return True
return False return False
cdef bool _has_head_as_ancestor(int tokenid, int head, const vector[int]& heads) nogil: cdef bool _has_head_as_ancestor(int tokenid, int head, const vector[int]& heads) nogil except *:
ancestor = tokenid ancestor = tokenid
cnt = 0 cdef unordered_set[int] seen_tokens
while cnt < heads.size(): seen_tokens.insert(ancestor)
while True:
# Reached the head or a disconnected node
if heads[ancestor] == head or heads[ancestor] < 0: if heads[ancestor] == head or heads[ancestor] < 0:
return True return True
# Reached the root
if heads[ancestor] == ancestor:
return False
ancestor = heads[ancestor] ancestor = heads[ancestor]
cnt += 1 result = seen_tokens.insert(ancestor)
# Found cycle
if not result.second:
raise_domain_error(heads_to_string(heads))
return False return False
cdef string heads_to_string(const vector[int]& heads) nogil:
cdef vector[int].const_iterator citer
cdef string cycle_str
cycle_str.append("Found cycle in dependency graph: [")
# FIXME: Rewrite using ostringstream when available in Cython.
citer = heads.const_begin()
while citer != heads.const_end():
if citer != heads.const_begin():
cycle_str.append(", ")
cycle_str.append(to_string(deref(citer)))
incr(citer)
cycle_str.append("]")
return cycle_str
def is_nonproj_tree(heads): def is_nonproj_tree(heads):
cdef vector[int] c_heads = _heads_to_c(heads) cdef vector[int] c_heads = _heads_to_c(heads)
# a tree is non-projective if at least one arc is non-projective # a tree is non-projective if at least one arc is non-projective
@ -176,11 +203,12 @@ def get_smallest_nonproj_arc_slow(heads):
return _get_smallest_nonproj_arc(c_heads) return _get_smallest_nonproj_arc(c_heads)
cdef int _get_smallest_nonproj_arc(const vector[int]& heads) nogil: cdef int _get_smallest_nonproj_arc(const vector[int]& heads) nogil except -2:
# return the smallest non-proj arc or None # return the smallest non-proj arc or None
# where size is defined as the distance between dep and head # where size is defined as the distance between dep and head
# and ties are broken left to right # and ties are broken left to right
cdef int smallest_size = INT_MAX cdef int smallest_size = INT_MAX
# -1 means its already projective.
cdef int smallest_np_arc = -1 cdef int smallest_np_arc = -1
cdef int size cdef int size
cdef int tokenid cdef int tokenid

View File

@ -49,7 +49,7 @@ def test_parser_contains_cycle(tree, cyclic_tree, partial_tree, multirooted_tree
assert contains_cycle(multirooted_tree) is None assert contains_cycle(multirooted_tree) is None
def test_parser_is_nonproj_arc(nonproj_tree, partial_tree, multirooted_tree): def test_parser_is_nonproj_arc(cyclic_tree, nonproj_tree, partial_tree, multirooted_tree):
assert is_nonproj_arc(0, nonproj_tree) is False assert is_nonproj_arc(0, nonproj_tree) is False
assert is_nonproj_arc(1, nonproj_tree) is False assert is_nonproj_arc(1, nonproj_tree) is False
assert is_nonproj_arc(2, nonproj_tree) is False assert is_nonproj_arc(2, nonproj_tree) is False
@ -62,15 +62,19 @@ def test_parser_is_nonproj_arc(nonproj_tree, partial_tree, multirooted_tree):
assert is_nonproj_arc(7, partial_tree) is False assert is_nonproj_arc(7, partial_tree) is False
assert is_nonproj_arc(17, multirooted_tree) is False assert is_nonproj_arc(17, multirooted_tree) is False
assert is_nonproj_arc(16, multirooted_tree) is True assert is_nonproj_arc(16, multirooted_tree) is True
with pytest.raises(ValueError, match=r'Found cycle in dependency graph: \[1, 2, 2, 4, 5, 3, 2\]'):
is_nonproj_arc(6, cyclic_tree)
def test_parser_is_nonproj_tree( def test_parser_is_nonproj_tree(
proj_tree, nonproj_tree, partial_tree, multirooted_tree proj_tree, cyclic_tree, nonproj_tree, partial_tree, multirooted_tree
): ):
assert is_nonproj_tree(proj_tree) is False assert is_nonproj_tree(proj_tree) is False
assert is_nonproj_tree(nonproj_tree) is True assert is_nonproj_tree(nonproj_tree) is True
assert is_nonproj_tree(partial_tree) is False assert is_nonproj_tree(partial_tree) is False
assert is_nonproj_tree(multirooted_tree) is True assert is_nonproj_tree(multirooted_tree) is True
with pytest.raises(ValueError, match=r'Found cycle in dependency graph: \[1, 2, 2, 4, 5, 3, 2\]'):
is_nonproj_tree(cyclic_tree)
def test_parser_pseudoprojectivity(en_vocab): def test_parser_pseudoprojectivity(en_vocab):
@ -84,8 +88,10 @@ def test_parser_pseudoprojectivity(en_vocab):
tree = [1, 2, 2] tree = [1, 2, 2]
nonproj_tree = [1, 2, 2, 4, 5, 2, 7, 4, 2] nonproj_tree = [1, 2, 2, 4, 5, 2, 7, 4, 2]
nonproj_tree2 = [9, 1, 3, 1, 5, 6, 9, 8, 6, 1, 6, 12, 13, 10, 1] nonproj_tree2 = [9, 1, 3, 1, 5, 6, 9, 8, 6, 1, 6, 12, 13, 10, 1]
cyclic_tree = [1, 2, 2, 4, 5, 3, 2]
labels = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl", "punct"] labels = ["det", "nsubj", "root", "det", "dobj", "aux", "nsubj", "acl", "punct"]
labels2 = ["advmod", "root", "det", "nsubj", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod", "advmod", "det", "amod", "punct"] labels2 = ["advmod", "root", "det", "nsubj", "advmod", "det", "dobj", "det", "nmod", "aux", "nmod", "advmod", "det", "amod", "punct"]
cyclic_labels = ["det", "nsubj", "root", "det", "dobj", "aux", "punct"]
# fmt: on # fmt: on
assert nonproj.decompose("X||Y") == ("X", "Y") assert nonproj.decompose("X||Y") == ("X", "Y")
assert nonproj.decompose("X") == ("X", "") assert nonproj.decompose("X") == ("X", "")
@ -97,6 +103,8 @@ def test_parser_pseudoprojectivity(en_vocab):
assert nonproj.get_smallest_nonproj_arc_slow(nonproj_tree2) == 10 assert nonproj.get_smallest_nonproj_arc_slow(nonproj_tree2) == 10
# fmt: off # fmt: off
proj_heads, deco_labels = nonproj.projectivize(nonproj_tree, labels) proj_heads, deco_labels = nonproj.projectivize(nonproj_tree, labels)
with pytest.raises(ValueError, match=r'Found cycle in dependency graph: \[1, 2, 2, 4, 5, 3, 2\]'):
nonproj.projectivize(cyclic_tree, cyclic_labels)
assert proj_heads == [1, 2, 2, 4, 5, 2, 7, 5, 2] assert proj_heads == [1, 2, 2, 4, 5, 2, 7, 5, 2]
assert deco_labels == ["det", "nsubj", "root", "det", "dobj", "aux", assert deco_labels == ["det", "nsubj", "root", "det", "dobj", "aux",
"nsubj", "acl||dobj", "punct"] "nsubj", "acl||dobj", "punct"]

View File

@ -671,13 +671,13 @@ def test_gold_ner_missing_tags(en_tokenizer):
def test_projectivize(en_tokenizer): def test_projectivize(en_tokenizer):
doc = en_tokenizer("He pretty quickly walks away") doc = en_tokenizer("He pretty quickly walks away")
heads = [3, 2, 3, 0, 2] heads = [3, 2, 3, 3, 2]
deps = ["dep"] * len(heads) deps = ["dep"] * len(heads)
example = Example.from_dict(doc, {"heads": heads, "deps": deps}) example = Example.from_dict(doc, {"heads": heads, "deps": deps})
proj_heads, proj_labels = example.get_aligned_parse(projectivize=True) proj_heads, proj_labels = example.get_aligned_parse(projectivize=True)
nonproj_heads, nonproj_labels = example.get_aligned_parse(projectivize=False) nonproj_heads, nonproj_labels = example.get_aligned_parse(projectivize=False)
assert proj_heads == [3, 2, 3, 0, 3] assert proj_heads == [3, 2, 3, 3, 3]
assert nonproj_heads == [3, 2, 3, 0, 2] assert nonproj_heads == [3, 2, 3, 3, 2]
def test_iob_to_biluo(): def test_iob_to_biluo():