Validate pos values when creating Doc (#9148)

* Validate pos values when creating Doc

* Add clear error when setting invalid pos

This also changes the error language slightly.

* Fix variable name

* Update spacy/tokens/doc.pyx

* Test that setting invalid pos raises an error

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Paul O'Leary McCann 2021-09-16 20:28:05 +09:00 committed by GitHub
parent 865cfbc903
commit c4f0800fb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 0 deletions

View File

@ -871,6 +871,8 @@ class Errors:
"the documentation:\nhttps://spacy.io/usage/models")
E1020 = ("No `epoch_resume` value specified and could not infer one from "
"filename. Specify an epoch to resume from.")
E1021 = ("`pos` value \"{pp}\" is not a valid Universal Dependencies tag. "
"Non-UD tags should use the `tag` property.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -70,3 +70,10 @@ def test_create_with_heads_and_no_deps(vocab):
heads = list(range(len(words)))
with pytest.raises(ValueError):
Doc(vocab, words=words, heads=heads)
def test_create_invalid_pos(vocab):
words = "I like ginger".split()
pos = "QQ ZZ XX".split()
with pytest.raises(ValueError):
Doc(vocab, words=words, pos=pos)

View File

@ -202,6 +202,10 @@ def test_set_pos():
doc[1].pos = VERB
assert doc[1].pos_ == "VERB"
def test_set_invalid_pos():
doc = Doc(Vocab(), words=["hello", "world"])
with pytest.raises(ValueError):
doc[0].pos_ = "blah"
def test_tokens_sent(doc):
"""Test token.sent property"""

View File

@ -30,6 +30,7 @@ from ..compat import copy_reg, pickle
from ..errors import Errors, Warnings
from ..morphology import Morphology
from .. import util
from .. import parts_of_speech
from .underscore import Underscore, get_ext_args
from ._retokenize import Retokenizer
from ._serialize import ALL_ATTRS as DOCBIN_ALL_ATTRS
@ -285,6 +286,10 @@ cdef class Doc:
sent_starts[i] = -1
elif sent_starts[i] is None or sent_starts[i] not in [-1, 0, 1]:
sent_starts[i] = 0
if pos is not None:
for pp in set(pos):
if pp not in parts_of_speech.IDS:
raise ValueError(Errors.E1021.format(pp=pp))
ent_iobs = None
ent_types = None
if ents is not None:

View File

@ -867,6 +867,8 @@ cdef class Token:
return parts_of_speech.NAMES[self.c.pos]
def __set__(self, pos_name):
if pos_name not in parts_of_speech.IDS:
raise ValueError(Errors.E1021.format(pp=pos_name))
self.c.pos = parts_of_speech.IDS[pos_name]
property tag_: