mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Require HEAD for is_parsed in Doc.from_array() (#5011)
Modify flag settings so that `DEP` is not sufficient to set `is_parsed` and only run `set_children_from_heads()` if `HEAD` is provided. Then the combination `[SENT_START, DEP]` will set deps and not clobber sent starts with a lot of one-word sentences.
This commit is contained in:
parent
2572460175
commit
5b102963bf
|
@ -7,7 +7,7 @@ import numpy
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.errors import ModelsWarning
|
from spacy.errors import ModelsWarning
|
||||||
from spacy.attrs import ENT_TYPE, ENT_IOB
|
from spacy.attrs import ENT_TYPE, ENT_IOB, SENT_START, HEAD, DEP
|
||||||
|
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
@ -274,6 +274,39 @@ def test_doc_is_nered(en_vocab):
|
||||||
assert new_doc.is_nered
|
assert new_doc.is_nered
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_from_array_sent_starts(en_vocab):
|
||||||
|
words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."]
|
||||||
|
heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
|
||||||
|
deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep", "dep"]
|
||||||
|
doc = Doc(en_vocab, words=words)
|
||||||
|
for i, (dep, head) in enumerate(zip(deps, heads)):
|
||||||
|
doc[i].dep_ = dep
|
||||||
|
doc[i].head = doc[head]
|
||||||
|
if head == i:
|
||||||
|
doc[i].is_sent_start = True
|
||||||
|
doc.is_parsed
|
||||||
|
|
||||||
|
attrs = [SENT_START, HEAD]
|
||||||
|
arr = doc.to_array(attrs)
|
||||||
|
new_doc = Doc(en_vocab, words=words)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
new_doc.from_array(attrs, arr)
|
||||||
|
|
||||||
|
attrs = [SENT_START, DEP]
|
||||||
|
arr = doc.to_array(attrs)
|
||||||
|
new_doc = Doc(en_vocab, words=words)
|
||||||
|
new_doc.from_array(attrs, arr)
|
||||||
|
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
|
||||||
|
assert not new_doc.is_parsed
|
||||||
|
|
||||||
|
attrs = [HEAD, DEP]
|
||||||
|
arr = doc.to_array(attrs)
|
||||||
|
new_doc = Doc(en_vocab, words=words)
|
||||||
|
new_doc.from_array(attrs, arr)
|
||||||
|
assert [t.is_sent_start for t in doc] == [t.is_sent_start for t in new_doc]
|
||||||
|
assert new_doc.is_parsed
|
||||||
|
|
||||||
|
|
||||||
def test_doc_lang(en_vocab):
|
def test_doc_lang(en_vocab):
|
||||||
doc = Doc(en_vocab, words=["Hello", "world"])
|
doc = Doc(en_vocab, words=["Hello", "world"])
|
||||||
assert doc.lang_ == "en"
|
assert doc.lang_ == "en"
|
||||||
|
|
|
@ -813,7 +813,7 @@ cdef class Doc:
|
||||||
if attr_ids[j] != TAG:
|
if attr_ids[j] != TAG:
|
||||||
Token.set_struct_attr(token, attr_ids[j], array[i, j])
|
Token.set_struct_attr(token, attr_ids[j], array[i, j])
|
||||||
# Set flags
|
# Set flags
|
||||||
self.is_parsed = bool(self.is_parsed or HEAD in attrs or DEP in attrs)
|
self.is_parsed = bool(self.is_parsed or HEAD in attrs)
|
||||||
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
|
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
|
||||||
# If document is parsed, set children
|
# If document is parsed, set children
|
||||||
if self.is_parsed:
|
if self.is_parsed:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user