Require HEAD for is_parsed in Doc.from_array() (#5011)

Modify flag settings so that `DEP` is not sufficient to set `is_parsed`
and only run `set_children_from_heads()` if `HEAD` is provided.

With this change, the combination `[SENT_START, DEP]` sets the dependency
labels without clobbering the sentence starts (i.e. without splitting the
doc into a lot of one-word sentences).
This commit is contained in:
adrianeboyd 2020-02-16 17:17:09 +01:00 committed by GitHub
parent 2572460175
commit 5b102963bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 2 deletions

View File

@ -7,7 +7,7 @@ import numpy
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.errors import ModelsWarning from spacy.errors import ModelsWarning
from spacy.attrs import ENT_TYPE, ENT_IOB from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP
from ..util import get_doc from ..util import get_doc
@ -274,6 +274,39 @@ def test_doc_is_nered(en_vocab):
assert new_doc.is_nered assert new_doc.is_nered
def test_doc_from_array_sent_starts(en_vocab):
    """Check how Doc.from_array() treats SENT_START, HEAD and DEP together.

    Three attribute combinations are exercised against a two-sentence doc:

    * [SENT_START, HEAD] is expected to raise ValueError.
    * [SENT_START, DEP] must keep the sentence starts and must not flag the
      new doc as parsed (DEP alone is not sufficient for is_parsed).
    * [HEAD, DEP] must flag the new doc as parsed and yield the same
      sentence starts as the source doc.
    """
    words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."]
    # Absolute indices of each token's head; tokens 0 and 6 head themselves.
    heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
    # One label per token (the list previously carried a stray 11th entry
    # that zip() silently dropped).
    deps = [
        "ROOT", "dep", "dep", "dep", "dep", "dep",
        "ROOT", "dep", "dep", "dep",
    ]
    assert len(words) == len(heads) == len(deps)

    doc = Doc(en_vocab, words=words)
    for i, (dep, head) in enumerate(zip(deps, heads)):
        doc[i].dep_ = dep
        doc[i].head = doc[head]
        if head == i:
            # A token that is its own head opens a new sentence.
            doc[i].is_sent_start = True
    # Flag the reference doc as parsed. This was a bare `doc.is_parsed`
    # expression before, which had no effect.
    doc.is_parsed = True

    expected_sent_starts = [t.is_sent_start for t in doc]

    # SENT_START together with HEAD is expected to be rejected.
    attrs = [SENT_START, HEAD]
    arr = doc.to_array(attrs)
    new_doc = Doc(en_vocab, words=words)
    with pytest.raises(ValueError):
        new_doc.from_array(attrs, arr)

    # SENT_START + DEP: sentence starts survive, doc is not marked parsed.
    attrs = [SENT_START, DEP]
    arr = doc.to_array(attrs)
    new_doc = Doc(en_vocab, words=words)
    new_doc.from_array(attrs, arr)
    assert expected_sent_starts == [t.is_sent_start for t in new_doc]
    assert not new_doc.is_parsed

    # HEAD + DEP: doc is marked parsed and sentence starts match the source.
    attrs = [HEAD, DEP]
    arr = doc.to_array(attrs)
    new_doc = Doc(en_vocab, words=words)
    new_doc.from_array(attrs, arr)
    assert expected_sent_starts == [t.is_sent_start for t in new_doc]
    assert new_doc.is_parsed
def test_doc_lang(en_vocab): def test_doc_lang(en_vocab):
doc = Doc(en_vocab, words=["Hello", "world"]) doc = Doc(en_vocab, words=["Hello", "world"])
assert doc.lang_ == "en" assert doc.lang_ == "en"

View File

@ -813,7 +813,7 @@ cdef class Doc:
if attr_ids[j] != TAG: if attr_ids[j] != TAG:
Token.set_struct_attr(token, attr_ids[j], array[i, j]) Token.set_struct_attr(token, attr_ids[j], array[i, j])
# Set flags # Set flags
self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs) self.is_parsed = bool(self.is_parsed or HEAD in attrs)
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
# If document is parsed, set children # If document is parsed, set children
if self.is_parsed: if self.is_parsed: