mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Don't raise error if set_extension has getter and setter (closes #2177)
Improve error messages, raise error if setter is specified without a getter and compare against _unset to allow default=None. Also add more tests.
This commit is contained in:
parent
ee3082ad29
commit
62b4b527d7
|
@ -218,8 +218,8 @@ class Errors(object):
|
||||||
E082 = ("Error deprojectivizing parse: number of heads ({n_heads}), "
|
E082 = ("Error deprojectivizing parse: number of heads ({n_heads}), "
|
||||||
"projective heads ({n_proj_heads}) and labels ({n_labels}) do not "
|
"projective heads ({n_proj_heads}) and labels ({n_labels}) do not "
|
||||||
"match.")
|
"match.")
|
||||||
E083 = ("Error setting extension: only one of default, getter, setter and "
|
E083 = ("Error setting extension: only one of `default`, `method`, or "
|
||||||
"method is allowed. {n_args} keyword arguments were specified.")
|
"`getter` (plus optional `setter`) is allowed. Got: {nr_defined}")
|
||||||
E084 = ("Error assigning label ID {label} to span: not in StringStore.")
|
E084 = ("Error assigning label ID {label} to span: not in StringStore.")
|
||||||
E085 = ("Can't create lexeme for string '{string}'.")
|
E085 = ("Can't create lexeme for string '{string}'.")
|
||||||
E086 = ("Error deserializing lexeme '{string}': orth ID {orth_id} does "
|
E086 = ("Error deserializing lexeme '{string}': orth ID {orth_id} does "
|
||||||
|
@ -233,6 +233,12 @@ class Errors(object):
|
||||||
"`nlp.max_length` limit. The limit is in number of characters, so "
|
"`nlp.max_length` limit. The limit is in number of characters, so "
|
||||||
"you can check whether your inputs are too long by checking "
|
"you can check whether your inputs are too long by checking "
|
||||||
"`len(text)`.")
|
"`len(text)`.")
|
||||||
|
E089 = ("Extensions can't have a setter argument without a getter "
|
||||||
|
"argument. Check the keyword arguments on `set_extension`.")
|
||||||
|
E090 = ("Extension '{name}' already exists on {obj}. To overwrite the "
|
||||||
|
"existing extension, set `force=True` on `{obj}.set_extension`.")
|
||||||
|
E091 = ("Invalid extension attribute {name}: expected callable or None, "
|
||||||
|
"but got: {value}")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -1,4 +1,11 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from ..vocab import Vocab
|
||||||
|
from ..tokens.doc import Doc
|
||||||
from ..tokens.underscore import Underscore
|
from ..tokens.underscore import Underscore
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,3 +58,36 @@ def test_token_underscore_method():
|
||||||
None, None)
|
None, None)
|
||||||
token._ = Underscore(Underscore.token_extensions, token, start=token.idx)
|
token._ = Underscore(Underscore.token_extensions, token, start=token.idx)
|
||||||
assert token._.hello() == 'cheese'
|
assert token._.hello() == 'cheese'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('obj', [
    Doc(Vocab(), words=['hello', 'world']),
    Doc(Vocab(), words=['hello', 'world'])[1],
    Doc(Vocab(), words=['hello', 'world'])[0:2]])
def test_underscore_raises_for_dup(obj):
    """Re-registering an existing extension name without `force` must fail.

    Parametrized over a Doc, a single item indexed from a Doc and a slice
    of a Doc, so each object's `set_extension` is exercised.
    """
    # Extensions are stored in class-level registries on Underscore, so a
    # 'test' extension registered by another test may already exist. Force
    # the first registration so this setup step cannot raise by itself and
    # the test does not depend on execution order.
    obj.set_extension('test', default=None, force=True)
    with pytest.raises(ValueError):
        # Same name, no `force`: this is the call under test.
        obj.set_extension('test', default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('invalid_kwargs', [
    {'getter': None, 'setter': lambda: None},
    {'default': None, 'method': lambda: None, 'getter': lambda: None},
    {'setter': lambda: None},
    {'default': None, 'method': lambda: None},
    {'getter': True}])
def test_underscore_raises_for_invalid(invalid_kwargs):
    """Invalid keyword combinations for `set_extension` raise ValueError.

    Covers: a setter without a usable getter, more than one of
    default/method/getter at once, and a non-callable getter.
    """
    doc = Doc(Vocab(), words=['hello', 'world'])
    with pytest.raises(ValueError):
        # `force=True` rules out the "already exists" error, so the
        # ValueError must come from argument validation. The explicit
        # keyword goes *before* the ** unpacking: a keyword after
        # `**kwargs` is a SyntaxError on Python 2.7 and Python < 3.5.
        doc.set_extension('test', force=True, **invalid_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('valid_kwargs', [
    {'getter': lambda: None},
    {'getter': lambda: None, 'setter': lambda: None},
    {'default': 'hello'},
    {'default': None},
    {'method': lambda: None}])
def test_underscore_accepts_valid(valid_kwargs):
    """Valid keyword combinations for `set_extension` are accepted.

    Covers: getter only, getter plus setter, a default value (including
    an explicit `default=None`) and a method.
    """
    doc = Doc(Vocab(), words=['hello', 'world'])
    # `force=True` lets every parametrized case reuse the name 'test'.
    # The explicit keyword goes *before* the ** unpacking: a keyword after
    # `**kwargs` is a SyntaxError on Python 2.7 and Python < 3.5.
    doc.set_extension('test', force=True, **valid_kwargs)
|
||||||
|
|
|
@ -33,7 +33,7 @@ from ..util import normalize_slice
|
||||||
from ..compat import is_config, copy_reg, pickle, basestring_
|
from ..compat import is_config, copy_reg, pickle, basestring_
|
||||||
from ..errors import Errors, Warnings, deprecation_warning
|
from ..errors import Errors, Warnings, deprecation_warning
|
||||||
from .. import util
|
from .. import util
|
||||||
from .underscore import Underscore
|
from .underscore import Underscore, get_ext_args
|
||||||
from ._retokenize import Retokenizer
|
from ._retokenize import Retokenizer
|
||||||
|
|
||||||
DEF PADDING = 5
|
DEF PADDING = 5
|
||||||
|
@ -95,12 +95,10 @@ cdef class Doc:
|
||||||
spaces=[True, False, False])
|
spaces=[True, False, False])
|
||||||
"""
|
"""
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_extension(cls, name, default=None, method=None,
|
def set_extension(cls, name, **kwargs):
|
||||||
getter=None, setter=None):
|
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||||
nr_defined = sum(t is not None for t in (default, getter, setter, method))
|
raise ValueError(Errors.E090.format(name=name, obj='Doc'))
|
||||||
if nr_defined != 1:
|
Underscore.doc_extensions[name] = get_ext_args(**kwargs)
|
||||||
raise ValueError(Errors.E083.format(n_args=nr_defined))
|
|
||||||
Underscore.doc_extensions[name] = (default, method, getter, setter)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_extension(cls, name):
|
def get_extension(cls, name):
|
||||||
|
|
|
@ -17,15 +17,16 @@ from ..attrs cimport IS_PUNCT, IS_SPACE
|
||||||
from ..lexeme cimport Lexeme
|
from ..lexeme cimport Lexeme
|
||||||
from ..compat import is_config
|
from ..compat import is_config
|
||||||
from ..errors import Errors, TempErrors
|
from ..errors import Errors, TempErrors
|
||||||
from .underscore import Underscore
|
from .underscore import Underscore, get_ext_args
|
||||||
|
|
||||||
|
|
||||||
cdef class Span:
|
cdef class Span:
|
||||||
"""A slice from a Doc object."""
|
"""A slice from a Doc object."""
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_extension(cls, name, default=None, method=None,
|
def set_extension(cls, name, **kwargs):
|
||||||
getter=None, setter=None):
|
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||||
Underscore.span_extensions[name] = (default, method, getter, setter)
|
raise ValueError(Errors.E090.format(name=name, obj='Span'))
|
||||||
|
Underscore.span_extensions[name] = get_ext_args(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_extension(cls, name):
|
def get_extension(cls, name):
|
||||||
|
|
|
@ -21,16 +21,17 @@ from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
|
||||||
from ..compat import is_config
|
from ..compat import is_config
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from .. import util
|
from .. import util
|
||||||
from .underscore import Underscore
|
from .underscore import Underscore, get_ext_args
|
||||||
|
|
||||||
|
|
||||||
cdef class Token:
|
cdef class Token:
|
||||||
"""An individual token – i.e. a word, punctuation symbol, whitespace,
|
"""An individual token – i.e. a word, punctuation symbol, whitespace,
|
||||||
etc."""
|
etc."""
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_extension(cls, name, default=None, method=None,
|
def set_extension(cls, name, **kwargs):
|
||||||
getter=None, setter=None):
|
if cls.has_extension(name) and not kwargs.get('force', False):
|
||||||
Underscore.token_extensions[name] = (default, method, getter, setter)
|
raise ValueError(Errors.E090.format(name=name, obj='Token'))
|
||||||
|
Underscore.token_extensions[name] = get_ext_args(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_extension(cls, name):
|
def get_extension(cls, name):
|
||||||
|
|
|
@ -54,3 +54,24 @@ class Underscore(object):
|
||||||
|
|
||||||
def _get_key(self, name):
|
def _get_key(self, name):
|
||||||
return ('._.', name, self._start, self._end)
|
return ('._.', name, self._start, self._end)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ext_args(**kwargs):
    """Validate and convert arguments. Reused in Doc, Token and Span.

    Reads the keys `default`, `getter`, `setter` and `method` from the
    keyword arguments; any other key (for example `force`) is ignored
    here. Exactly one of `default`, `method` or `getter` must be given,
    and a `setter` is only accepted together with a `getter`.

    Returns the tuple `(default, method, getter, setter)`, with None in
    the place of every argument that was not supplied.

    Raises ValueError if a setter is given without a getter (E089), if
    the number of default/method/getter arguments is not exactly one
    (E083), or if a setter, getter or method is not callable (E091).
    """
    default = kwargs.get('default')
    getter = kwargs.get('getter')
    setter = kwargs.get('setter')
    method = kwargs.get('method')
    if getter is None and setter is not None:
        raise ValueError(Errors.E089)
    # `default` is detected by key presence rather than by value, so an
    # explicit `default=None` still counts as a supplied default.
    valid_opts = ('default' in kwargs, method is not None, getter is not None)
    # All three entries are real booleans, so summing them counts the Trues.
    nr_defined = sum(valid_opts)
    if nr_defined != 1:
        raise ValueError(Errors.E083.format(nr_defined=nr_defined))
    # Checked in the order setter, getter, method. `callable()` is used
    # instead of probing for a `__call__` attribute, which would also
    # accept objects that merely carry an instance attribute of that name.
    for name, value in (('setter', setter), ('getter', getter),
                        ('method', method)):
        if value is not None and not callable(value):
            raise ValueError(Errors.E091.format(name=name, value=repr(value)))
    return (default, method, getter, setter)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user