Don't raise error if set_extension has getter and setter (closes #2177)

Improve error messages, raise error if setter is specified without a getter and compare against _unset to allow default=None. Also add more tests.
This commit is contained in:
ines 2018-04-03 18:30:17 +02:00
parent ee3082ad29
commit 62b4b527d7
6 changed files with 84 additions and 17 deletions

View File

@ -218,8 +218,8 @@ class Errors(object):
E082 = ("Error deprojectivizing parse: number of heads ({n_heads}), "
"projective heads ({n_proj_heads}) and labels ({n_labels}) do not "
"match.")
E083 = ("Error setting extension: only one of default, getter, setter and "
"method is allowed. {n_args} keyword arguments were specified.")
E083 = ("Error setting extension: only one of `default`, `method`, or "
"`getter` (plus optional `setter`) is allowed. Got: {nr_defined}")
E084 = ("Error assigning label ID {label} to span: not in StringStore.")
E085 = ("Can't create lexeme for string '{string}'.")
E086 = ("Error deserializing lexeme '{string}': orth ID {orth_id} does "
@ -233,6 +233,12 @@ class Errors(object):
"`nlp.max_length` limit. The limit is in number of characters, so "
"you can check whether your inputs are too long by checking "
"`len(text)`.")
E089 = ("Extensions can't have a setter argument without a getter "
"argument. Check the keyword arguments on `set_extension`.")
E090 = ("Extension '{name}' already exists on {obj}. To overwrite the "
"existing extension, set `force=True` on `{obj}.set_extension`.")
E091 = ("Invalid extension attribute {name}: expected callable or None, "
"but got: {value}")
@add_codes

View File

@ -1,4 +1,11 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from mock import Mock
from ..vocab import Vocab
from ..tokens.doc import Doc
from ..tokens.underscore import Underscore
@ -51,3 +58,36 @@ def test_token_underscore_method():
None, None)
token._ = Underscore(Underscore.token_extensions, token, start=token.idx)
assert token._.hello() == 'cheese'
@pytest.mark.parametrize('obj', [
Doc(Vocab(), words=['hello', 'world']),
Doc(Vocab(), words=['hello', 'world'])[1],
Doc(Vocab(), words=['hello', 'world'])[0:2]])
def test_underscore_raises_for_dup(obj):
obj.set_extension('test', default=None)
with pytest.raises(ValueError):
obj.set_extension('test', default=None)
@pytest.mark.parametrize('invalid_kwargs', [
{'getter': None, 'setter': lambda: None},
{'default': None, 'method': lambda: None, 'getter': lambda: None},
{'setter': lambda: None},
{'default': None, 'method': lambda: None},
{'getter': True}])
def test_underscore_raises_for_invalid(invalid_kwargs):
doc = Doc(Vocab(), words=['hello', 'world'])
with pytest.raises(ValueError):
doc.set_extension('test', **invalid_kwargs, force=True)
@pytest.mark.parametrize('valid_kwargs', [
{'getter': lambda: None},
{'getter': lambda: None, 'setter': lambda: None},
{'default': 'hello'},
{'default': None},
{'method': lambda: None}])
def test_underscore_accepts_valid(valid_kwargs):
doc = Doc(Vocab(), words=['hello', 'world'])
doc.set_extension('test', **valid_kwargs, force=True)

View File

@ -33,7 +33,7 @@ from ..util import normalize_slice
from ..compat import is_config, copy_reg, pickle, basestring_
from ..errors import Errors, Warnings, deprecation_warning
from .. import util
from .underscore import Underscore
from .underscore import Underscore, get_ext_args
from ._retokenize import Retokenizer
DEF PADDING = 5
@ -95,12 +95,10 @@ cdef class Doc:
spaces=[True, False, False])
"""
@classmethod
def set_extension(cls, name, default=None, method=None,
getter=None, setter=None):
nr_defined = sum(t is not None for t in (default, getter, setter, method))
if nr_defined != 1:
raise ValueError(Errors.E083.format(n_args=nr_defined))
Underscore.doc_extensions[name] = (default, method, getter, setter)
def set_extension(cls, name, **kwargs):
if cls.has_extension(name) and not kwargs.get('force', False):
raise ValueError(Errors.E090.format(name=name, obj='Doc'))
Underscore.doc_extensions[name] = get_ext_args(**kwargs)
@classmethod
def get_extension(cls, name):

View File

@ -17,15 +17,16 @@ from ..attrs cimport IS_PUNCT, IS_SPACE
from ..lexeme cimport Lexeme
from ..compat import is_config
from ..errors import Errors, TempErrors
from .underscore import Underscore
from .underscore import Underscore, get_ext_args
cdef class Span:
"""A slice from a Doc object."""
@classmethod
def set_extension(cls, name, default=None, method=None,
getter=None, setter=None):
Underscore.span_extensions[name] = (default, method, getter, setter)
def set_extension(cls, name, **kwargs):
if cls.has_extension(name) and not kwargs.get('force', False):
raise ValueError(Errors.E090.format(name=name, obj='Span'))
Underscore.span_extensions[name] = get_ext_args(**kwargs)
@classmethod
def get_extension(cls, name):

View File

@ -21,16 +21,17 @@ from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
from ..compat import is_config
from ..errors import Errors
from .. import util
from .underscore import Underscore
from .underscore import Underscore, get_ext_args
cdef class Token:
"""An individual token i.e. a word, punctuation symbol, whitespace,
etc."""
@classmethod
def set_extension(cls, name, default=None, method=None,
getter=None, setter=None):
Underscore.token_extensions[name] = (default, method, getter, setter)
def set_extension(cls, name, **kwargs):
if cls.has_extension(name) and not kwargs.get('force', False):
raise ValueError(Errors.E090.format(name=name, obj='Token'))
Underscore.token_extensions[name] = get_ext_args(**kwargs)
@classmethod
def get_extension(cls, name):

View File

@ -54,3 +54,24 @@ class Underscore(object):
def _get_key(self, name):
return ('._.', name, self._start, self._end)
def get_ext_args(**kwargs):
"""Validate and convert arguments. Reused in Doc, Token and Span."""
default = kwargs.get('default')
getter = kwargs.get('getter')
setter = kwargs.get('setter')
method = kwargs.get('method')
if getter is None and setter is not None:
raise ValueError(Errors.E089)
valid_opts = ('default' in kwargs, method is not None, getter is not None)
nr_defined = sum(t is True for t in valid_opts)
if nr_defined != 1:
raise ValueError(Errors.E083.format(nr_defined=nr_defined))
if setter is not None and not hasattr(setter, '__call__'):
raise ValueError(Errors.E091.format(name='setter', value=repr(setter)))
if getter is not None and not hasattr(getter, '__call__'):
raise ValueError(Errors.E091.format(name='getter', value=repr(getter)))
if method is not None and not hasattr(method, '__call__'):
raise ValueError(Errors.E091.format(name='method', value=repr(method)))
return (default, method, getter, setter)