mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Require HEAD for is_parsed in Doc.from_array() (#5011)
Modify flag settings so that `DEP` is not sufficient to set `is_parsed` and only run `set_children_from_heads()` if `HEAD` is provided. Then the combination `[SENT_START, DEP]` will set deps and not clobber sent starts with a lot of one-word sentences.
This commit is contained in:
parent
2572460175
commit
5b102963bf
|
@ -7,7 +7,7 @@ import numpy
|
|||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.errors import ModelsWarning
|
||||
from spacy.attrs import ENT_TYPE, ENT_IOB
|
||||
from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP
|
||||
|
||||
from ..util import get_doc
|
||||
|
||||
|
@ -274,6 +274,39 @@ def test_doc_is_nered(en_vocab):
|
|||
assert new_doc.is_nered
|
||||
|
||||
|
||||
def test_doc_from_array_sent_starts(en_vocab):
|
||||
words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."]
|
||||
heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
|
||||
deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep", "dep"]
|
||||
doc = Doc(en_vocab, words=words)
|
||||
for i, (dep, head) in enumerate(zip(deps, heads)):
|
||||
doc[i].dep_ = dep
|
||||
doc[i].head = doc[head]
|
||||
if head == i:
|
||||
doc[i].is_sent_start = True
|
||||
doc.is_parsed
|
||||
|
||||
attrs = [SENT_START, HEAD]
|
||||
arr = doc.to_array(attrs)
|
||||
new_doc = Doc(en_vocab, words=words)
|
||||
with pytest.raises(ValueError):
|
||||
new_doc.from_array(attrs, arr)
|
||||
|
||||
attrs = [SENT_START, DEP]
|
||||
arr = doc.to_array(attrs)
|
||||
new_doc = Doc(en_vocab, words=words)
|
||||
new_doc.from_array(attrs, arr)
|
||||
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
|
||||
assert not new_doc.is_parsed
|
||||
|
||||
attrs = [HEAD, DEP]
|
||||
arr = doc.to_array(attrs)
|
||||
new_doc = Doc(en_vocab, words=words)
|
||||
new_doc.from_array(attrs, arr)
|
||||
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
|
||||
assert new_doc.is_parsed
|
||||
|
||||
|
||||
def test_doc_lang(en_vocab):
|
||||
doc = Doc(en_vocab, words=["Hello", "world"])
|
||||
assert doc.lang_ == "en"
|
||||
|
|
|
@ -813,7 +813,7 @@ cdef class Doc:
|
|||
if attr_ids[j] != TAG:
|
||||
Token.set_struct_attr(token, attr_ids[j], array[i, j])
|
||||
# Set flags
|
||||
self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs)
|
||||
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
|
||||
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
|
||||
# If document is parsed, set children
|
||||
if self.is_parsed:
|
||||
|
|
Loading…
Reference in New Issue
Block a user