mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Don't raise error if set_extension has getter and setter (closes #2177)
Improve error messages, raise error if setter is specified without a getter and compare against _unset to allow default=None. Also add more tests.
This commit is contained in:
parent
ee3082ad29
commit
62b4b527d7
|
@ -218,8 +218,8 @@ class Errors(object):
|
|||
E082 = ("Error deprojectivizing parse: number of heads ({n_heads}), "
|
||||
"projective heads ({n_proj_heads}) and labels ({n_labels}) do not "
|
||||
"match.")
|
||||
E083 = ("Error setting extension: only one of default, getter, setter and "
|
||||
"method is allowed. {n_args} keyword arguments were specified.")
|
||||
E083 = ("Error setting extension: only one of `default`, `method`, or "
|
||||
"`getter` (plus optional `setter`) is allowed. Got: {nr_defined}")
|
||||
E084 = ("Error assigning label ID {label} to span: not in StringStore.")
|
||||
E085 = ("Can't create lexeme for string '{string}'.")
|
||||
E086 = ("Error deserializing lexeme '{string}': orth ID {orth_id} does "
|
||||
|
@ -233,6 +233,12 @@ class Errors(object):
|
|||
"`nlp.max_length` limit. The limit is in number of characters, so "
|
||||
"you can check whether your inputs are too long by checking "
|
||||
"`len(text)`.")
|
||||
E089 = ("Extensions can't have a setter argument without a getter "
|
||||
"argument. Check the keyword arguments on `set_extension`.")
|
||||
E090 = ("Extension '{name}' already exists on {obj}. To overwrite the "
|
||||
"existing extension, set `force=True` on `{obj}.set_extension`.")
|
||||
E091 = ("Invalid extension attribute {name}: expected callable or None, "
|
||||
"but got: {value}")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -1,4 +1,11 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from mock import Mock
|
||||
|
||||
from ..vocab import Vocab
|
||||
from ..tokens.doc import Doc
|
||||
from ..tokens.underscore import Underscore
|
||||
|
||||
|
||||
|
@ -51,3 +58,36 @@ def test_token_underscore_method():
|
|||
None, None)
|
||||
token._ = Underscore(Underscore.token_extensions, token, start=token.idx)
|
||||
assert token._.hello() == 'cheese'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('obj', [
|
||||
Doc(Vocab(), words=['hello', 'world']),
|
||||
Doc(Vocab(), words=['hello', 'world'])[1],
|
||||
Doc(Vocab(), words=['hello', 'world'])[0:2]])
|
||||
def test_underscore_raises_for_dup(obj):
|
||||
obj.set_extension('test', default=None)
|
||||
with pytest.raises(ValueError):
|
||||
obj.set_extension('test', default=None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('invalid_kwargs', [
|
||||
{'getter': None, 'setter': lambda: None},
|
||||
{'default': None, 'method': lambda: None, 'getter': lambda: None},
|
||||
{'setter': lambda: None},
|
||||
{'default': None, 'method': lambda: None},
|
||||
{'getter': True}])
|
||||
def test_underscore_raises_for_invalid(invalid_kwargs):
|
||||
doc = Doc(Vocab(), words=['hello', 'world'])
|
||||
with pytest.raises(ValueError):
|
||||
doc.set_extension('test', **invalid_kwargs, force=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('valid_kwargs', [
|
||||
{'getter': lambda: None},
|
||||
{'getter': lambda: None, 'setter': lambda: None},
|
||||
{'default': 'hello'},
|
||||
{'default': None},
|
||||
{'method': lambda: None}])
|
||||
def test_underscore_accepts_valid(valid_kwargs):
|
||||
doc = Doc(Vocab(), words=['hello', 'world'])
|
||||
doc.set_extension('test', **valid_kwargs, force=True)
|
||||
|
|
|
@ -33,7 +33,7 @@ from ..util import normalize_slice
|
|||
from ..compat import is_config, copy_reg, pickle, basestring_
|
||||
from ..errors import Errors, Warnings, deprecation_warning
|
||||
from .. import util
|
||||
from .underscore import Underscore
|
||||
from .underscore import Underscore, get_ext_args
|
||||
from ._retokenize import Retokenizer
|
||||
|
||||
DEF PADDING = 5
|
||||
|
@ -95,12 +95,10 @@ cdef class Doc:
|
|||
spaces=[True, False, False])
|
||||
"""
|
||||
@classmethod
|
||||
def set_extension(cls, name, default=None, method=None,
|
||||
getter=None, setter=None):
|
||||
nr_defined = sum(t is not None for t in (default, getter, setter, method))
|
||||
if nr_defined != 1:
|
||||
raise ValueError(Errors.E083.format(n_args=nr_defined))
|
||||
Underscore.doc_extensions[name] = (default, method, getter, setter)
|
||||
def set_extension(cls, name, **kwargs):
|
||||
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||
raise ValueError(Errors.E090.format(name=name, obj='Doc'))
|
||||
Underscore.doc_extensions[name] = get_ext_args(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, name):
|
||||
|
|
|
@ -17,15 +17,16 @@ from ..attrs cimport IS_PUNCT, IS_SPACE
|
|||
from ..lexeme cimport Lexeme
|
||||
from ..compat import is_config
|
||||
from ..errors import Errors, TempErrors
|
||||
from .underscore import Underscore
|
||||
from .underscore import Underscore, get_ext_args
|
||||
|
||||
|
||||
cdef class Span:
|
||||
"""A slice from a Doc object."""
|
||||
@classmethod
|
||||
def set_extension(cls, name, default=None, method=None,
|
||||
getter=None, setter=None):
|
||||
Underscore.span_extensions[name] = (default, method, getter, setter)
|
||||
def set_extension(cls, name, **kwargs):
|
||||
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||
raise ValueError(Errors.E090.format(name=name, obj='Span'))
|
||||
Underscore.span_extensions[name] = get_ext_args(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, name):
|
||||
|
|
|
@ -21,16 +21,17 @@ from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
|
|||
from ..compat import is_config
|
||||
from ..errors import Errors
|
||||
from .. import util
|
||||
from .underscore import Underscore
|
||||
from .underscore import Underscore, get_ext_args
|
||||
|
||||
|
||||
cdef class Token:
|
||||
"""An individual token – i.e. a word, punctuation symbol, whitespace,
|
||||
etc."""
|
||||
@classmethod
|
||||
def set_extension(cls, name, default=None, method=None,
|
||||
getter=None, setter=None):
|
||||
Underscore.token_extensions[name] = (default, method, getter, setter)
|
||||
def set_extension(cls, name, **kwargs):
|
||||
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||
raise ValueError(Errors.E090.format(name=name, obj='Token'))
|
||||
Underscore.token_extensions[name] = get_ext_args(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, name):
|
||||
|
|
|
@ -54,3 +54,24 @@ class Underscore(object):
|
|||
|
||||
def _get_key(self, name):
|
||||
return ('._.', name, self._start, self._end)
|
||||
|
||||
|
||||
def get_ext_args(**kwargs):
|
||||
"""Validate and convert arguments. Reused in Doc, Token and Span."""
|
||||
default = kwargs.get('default')
|
||||
getter = kwargs.get('getter')
|
||||
setter = kwargs.get('setter')
|
||||
method = kwargs.get('method')
|
||||
if getter is None and setter is not None:
|
||||
raise ValueError(Errors.E089)
|
||||
valid_opts = ('default' in kwargs, method is not None, getter is not None)
|
||||
nr_defined = sum(t is True for t in valid_opts)
|
||||
if nr_defined != 1:
|
||||
raise ValueError(Errors.E083.format(nr_defined=nr_defined))
|
||||
if setter is not None and not hasattr(setter, '__call__'):
|
||||
raise ValueError(Errors.E091.format(name='setter', value=repr(setter)))
|
||||
if getter is not None and not hasattr(getter, '__call__'):
|
||||
raise ValueError(Errors.E091.format(name='getter', value=repr(getter)))
|
||||
if method is not None and not hasattr(method, '__call__'):
|
||||
raise ValueError(Errors.E091.format(name='method', value=repr(method)))
|
||||
return (default, method, getter, setter)
|
||||
|
|
Loading…
Reference in New Issue
Block a user