diff --git a/spacy/tests/doc/test_token_api.py b/spacy/tests/doc/test_token_api.py index d3fb044ee..0795080a5 100644 --- a/spacy/tests/doc/test_token_api.py +++ b/spacy/tests/doc/test_token_api.py @@ -261,14 +261,18 @@ def test_missing_head_dep(en_vocab): doc = Doc(en_vocab, words=words, heads=heads, deps=deps) pred_has_heads = [t.has_head() for t in doc] pred_deps = [t.dep_ for t in doc] + pred_sent_starts = [t.is_sent_start for t in doc] assert pred_has_heads == [True, True, True, True, True, False] assert pred_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_] + assert pred_sent_starts == [True, False, False, False, False, False] example = Example.from_dict(doc, {"heads": heads, "deps": deps}) ref_heads = [t.head.i for t in example.reference] ref_deps = [t.dep_ for t in example.reference] ref_has_heads = [t.has_head() for t in example.reference] + ref_sent_starts = [t.is_sent_start for t in example.reference] assert ref_deps == ["nsubj", "ROOT", "dobj", "cc", "conj", MISSING_DEP_] assert ref_has_heads == [True, True, True, True, True, False] + assert ref_sent_starts == [True, False, False, False, False, False] aligned_heads, aligned_deps = example.get_aligned_parse(projectivize=True) assert aligned_heads[5] == ref_heads[5] assert aligned_deps[5] == MISSING_DEP_ diff --git a/spacy/tests/training/test_new_example.py b/spacy/tests/training/test_new_example.py index 6b6486b2b..0a3184071 100644 --- a/spacy/tests/training/test_new_example.py +++ b/spacy/tests/training/test_new_example.py @@ -282,3 +282,24 @@ def test_Example_missing_deps(): # when providing deps, the head information is actually used example_2 = Example.from_dict(predicted, annots_head_dep) assert [t.head.i for t in example_2.reference] == heads + + +def test_Example_missing_heads(): + vocab = Vocab() + words = ["I", "like", "London", "and", "Berlin", "."] + deps = ["nsubj", "ROOT", "dobj", None, "conj", "punct"] + heads = [1, 1, 1, None, 2, 1] + annots = {"words": words, "heads": heads, "deps": deps} + predicted = Doc(vocab, words=words) + + example = Example.from_dict(predicted, annots) + parsed_heads = [t.head.i for t in example.reference] + assert parsed_heads[0] == heads[0] + assert parsed_heads[1] == heads[1] + assert parsed_heads[2] == heads[2] + assert parsed_heads[4] == heads[4] + assert parsed_heads[5] == heads[5] + assert [t.has_head() for t in example.reference] == [True, True, True, False, True, True] + + # Ensure that the missing head doesn't create an artificial new sentence start + assert example.get_aligned_sent_starts() == [True, False, False, False, False, False] diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index fc14fb506..221e78b2e 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1540,7 +1540,7 @@ cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1: for i in range(start, end): tokens[i].sent_start = -1 for i in range(start, end): - if tokens[i].head == 0: + if tokens[i].head == 0 and not Token.missing_head(&tokens[i]): tokens[tokens[i].l_edge].sent_start = 1 diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd index 45c906a82..9006c874c 100644 --- a/spacy/tokens/token.pxd +++ b/spacy/tokens/token.pxd @@ -94,3 +94,10 @@ cdef class Token: token.ent_kb_id = value elif feat_name == SENT_START: token.sent_start = value + + @staticmethod + cdef inline int missing_head(const TokenC* token) nogil: + if token.dep == 0: + return 1 + else: + return 0 diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index 856719893..3303a8456 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -184,7 +184,10 @@ cdef class Example: heads = [token.head.i for token in self.y] deps = [token.dep_ for token in self.y] if projectivize: - heads, deps = nonproj.projectivize(heads, deps) + proj_heads, proj_deps = nonproj.projectivize(heads, deps) + # ensure that data that was previously missing, remains missing + heads = [h if has_heads[i] else heads[i] for i, h in enumerate(proj_heads)] + deps = [d if deps[i] != MISSING_DEP_ else MISSING_DEP_ for i, d in enumerate(proj_deps)] for cand_i in range(self.x.length): if cand_to_gold.lengths[cand_i] == 1: gold_i = cand_to_gold[cand_i].dataXd[0, 0]