diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 86c7fbf72..52f856d3e 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -7,7 +7,7 @@ import numpy from spacy.tokens import Doc, Span from spacy.vocab import Vocab from spacy.errors import ModelsWarning -from spacy.attrs import ENT_TYPE, ENT_IOB +from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP from ..util import get_doc @@ -274,6 +274,39 @@ def test_doc_is_nered(en_vocab): assert new_doc.is_nered +def test_doc_from_array_sent_starts(en_vocab): + words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."] + heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6] + deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep", "dep"] + doc = Doc(en_vocab, words=words) + for i, (dep, head) in enumerate(zip(deps, heads)): + doc[i].dep_ = dep + doc[i].head = doc[head] + if head == i: + doc[i].is_sent_start = True + doc.is_parsed + + attrs = [SENT_START, HEAD] + arr = doc.to_array(attrs) + new_doc = Doc(en_vocab, words=words) + with pytest.raises(ValueError): + new_doc.from_array(attrs, arr) + + attrs = [SENT_START, DEP] + arr = doc.to_array(attrs) + new_doc = Doc(en_vocab, words=words) + new_doc.from_array(attrs, arr) + assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc] + assert not new_doc.is_parsed + + attrs = [HEAD, DEP] + arr = doc.to_array(attrs) + new_doc = Doc(en_vocab, words=words) + new_doc.from_array(attrs, arr) + assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc] + assert new_doc.is_parsed + + def test_doc_lang(en_vocab): doc = Doc(en_vocab, words=["Hello", "world"]) assert doc.lang_ == "en" diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 4aee21153..04e02fd98 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -813,7 +813,7 @@ cdef class Doc: if attr_ids[j] != TAG: Token.set_struct_attr(token, attr_ids[j], array[i, j]) # Set flags - self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs) + self.is_parsed = bool(self.is_parsed or HEAD in attrs) self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) # If document is parsed, set children if self.is_parsed: