Validate pos values when creating Doc (#9148)

* Validate pos values when creating Doc

* Add clear error when setting invalid pos

This also changes the error language slightly.

* Fix variable name

* Update spacy/tokens/doc.pyx

* Test that setting invalid pos raises an error

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Paul O'Leary McCann 2021-09-16 20:28:05 +09:00 committed by GitHub
parent 865cfbc903
commit c4f0800fb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 0 deletions

View File

@ -871,6 +871,8 @@ class Errors:
"the documentation:\nhttps://spacy.io/usage/models") "the documentation:\nhttps://spacy.io/usage/models")
E1020 = ("No `epoch_resume` value specified and could not infer one from " E1020 = ("No `epoch_resume` value specified and could not infer one from "
"filename. Specify an epoch to resume from.") "filename. Specify an epoch to resume from.")
E1021 = ("`pos` value \"{pp}\" is not a valid Universal Dependencies tag. "
"Non-UD tags should use the `tag` property.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -70,3 +70,10 @@ def test_create_with_heads_and_no_deps(vocab):
heads = list(range(len(words))) heads = list(range(len(words)))
with pytest.raises(ValueError): with pytest.raises(ValueError):
Doc(vocab, words=words, heads=heads) Doc(vocab, words=words, heads=heads)
def test_create_invalid_pos(vocab):
words = "I like ginger".split()
pos = "QQ ZZ XX".split()
with pytest.raises(ValueError):
Doc(vocab, words=words, pos=pos)

View File

@ -202,6 +202,10 @@ def test_set_pos():
doc[1].pos = VERB doc[1].pos = VERB
assert doc[1].pos_ == "VERB" assert doc[1].pos_ == "VERB"
def test_set_invalid_pos():
doc = Doc(Vocab(), words=["hello", "world"])
with pytest.raises(ValueError):
doc[0].pos_ = "blah"
def test_tokens_sent(doc): def test_tokens_sent(doc):
"""Test token.sent property""" """Test token.sent property"""

View File

@ -30,6 +30,7 @@ from ..compat import copy_reg, pickle
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..morphology import Morphology from ..morphology import Morphology
from .. import util from .. import util
from .. import parts_of_speech
from .underscore import Underscore, get_ext_args from .underscore import Underscore, get_ext_args
from ._retokenize import Retokenizer from ._retokenize import Retokenizer
from ._serialize import ALL_ATTRS as DOCBIN_ALL_ATTRS from ._serialize import ALL_ATTRS as DOCBIN_ALL_ATTRS
@ -285,6 +286,10 @@ cdef class Doc:
sent_starts[i] = -1 sent_starts[i] = -1
elif sent_starts[i] is None or sent_starts[i] not in [-1, 0, 1]: elif sent_starts[i] is None or sent_starts[i] not in [-1, 0, 1]:
sent_starts[i] = 0 sent_starts[i] = 0
if pos is not None:
for pp in set(pos):
if pp not in parts_of_speech.IDS:
raise ValueError(Errors.E1021.format(pp=pp))
ent_iobs = None ent_iobs = None
ent_types = None ent_types = None
if ents is not None: if ents is not None:

View File

@ -867,6 +867,8 @@ cdef class Token:
return parts_of_speech.NAMES[self.c.pos] return parts_of_speech.NAMES[self.c.pos]
def __set__(self, pos_name): def __set__(self, pos_name):
if pos_name not in parts_of_speech.IDS:
raise ValueError(Errors.E1021.format(pp=pos_name))
self.c.pos = parts_of_speech.IDS[pos_name] self.c.pos = parts_of_speech.IDS[pos_name]
property tag_: property tag_: