Require HEAD for is_parsed in Doc.from_array() (#5011)

Modify flag settings so that `DEP` is not sufficient to set `is_parsed`
and only run `set_children_from_heads()` if `HEAD` is provided.

Then the combination `[SENT_START, DEP]` will set deps and not clobber
sent starts with a lot of one-word sentences.
This commit is contained in:
adrianeboyd 2020-02-16 17:17:09 +01:00 committed by GitHub
parent 2572460175
commit 5b102963bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 2 deletions

View File

@ -7,7 +7,7 @@ import numpy
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.errors import ModelsWarning
from spacy.attrs import ENT_TYPE, ENT_IOB
from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP
from ..util import get_doc
@ -274,6 +274,39 @@ def test_doc_is_nered(en_vocab):
assert new_doc.is_nered
def test_doc_from_array_sent_starts(en_vocab):
words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."]
heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep", "dep"]
doc = Doc(en_vocab, words=words)
for i, (dep, head) in enumerate(zip(deps, heads)):
doc[i].dep_ = dep
doc[i].head = doc[head]
if head == i:
doc[i].is_sent_start = True
doc.is_parsed
attrs = [SENT_START, HEAD]
arr = doc.to_array(attrs)
new_doc = Doc(en_vocab, words=words)
with pytest.raises(ValueError):
new_doc.from_array(attrs, arr)
attrs = [SENT_START, DEP]
arr = doc.to_array(attrs)
new_doc = Doc(en_vocab, words=words)
new_doc.from_array(attrs, arr)
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
assert not new_doc.is_parsed
attrs = [HEAD, DEP]
arr = doc.to_array(attrs)
new_doc = Doc(en_vocab, words=words)
new_doc.from_array(attrs, arr)
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
assert new_doc.is_parsed
def test_doc_lang(en_vocab):
doc = Doc(en_vocab, words=["Hello", "world"])
assert doc.lang_ == "en"

View File

@ -813,7 +813,7 @@ cdef class Doc:
if attr_ids[j] != TAG:
Token.set_struct_attr(token, attr_ids[j], array[i, j])
# Set flags
self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs)
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
# If document is parsed, set children
if self.is_parsed: