mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 12:44:15 +03:00
Added optional default_resolver to ObjectType.
This commit is contained in:
parent
62e58bd953
commit
98825fa4bc
|
@ -19,12 +19,13 @@ class ObjectTypeMeta(AbstractTypeMeta):
|
|||
if not is_base_type(bases, ObjectTypeMeta):
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
_meta = attrs.pop('_meta', None)
|
||||
options = _meta or Options(
|
||||
attrs.pop('_meta', None)
|
||||
options = Options(
|
||||
attrs.pop('Meta', None),
|
||||
name=name,
|
||||
description=trim_docstring(attrs.get('__doc__')),
|
||||
interfaces=(),
|
||||
default_resolver=None,
|
||||
local_fields=OrderedDict(),
|
||||
)
|
||||
options.base_fields = get_base_fields(bases, _as=Field)
|
||||
|
|
19
graphene/types/resolver.py
Normal file
19
graphene/types/resolver.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
def attr_resolver(attname, default_value, root, args, context, info):
|
||||
return getattr(root, attname, default_value)
|
||||
|
||||
|
||||
def dict_resolver(attname, default_value, root, args, context, info):
|
||||
return root.get(attname, default_value)
|
||||
|
||||
|
||||
default_resolver = attr_resolver
|
||||
|
||||
|
||||
def set_default_resolver(resolver):
|
||||
global default_resolver
|
||||
assert callable(resolver), 'Received non-callable resolver.'
|
||||
default_resolver = resolver
|
||||
|
||||
|
||||
def get_default_resolver():
|
||||
return default_resolver
|
48
graphene/types/tests/test_resolver.py
Normal file
48
graphene/types/tests/test_resolver.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from ..resolver import attr_resolver, dict_resolver, get_default_resolver, set_default_resolver
|
||||
|
||||
args = {}
|
||||
context = None
|
||||
info = None
|
||||
|
||||
demo_dict = {
|
||||
'attr': 'value'
|
||||
}
|
||||
|
||||
|
||||
class demo_obj(object):
|
||||
attr = 'value'
|
||||
|
||||
|
||||
def test_attr_resolver():
|
||||
resolved = attr_resolver('attr', None, demo_obj, args, context, info)
|
||||
assert resolved == 'value'
|
||||
|
||||
|
||||
def test_attr_resolver_default_value():
|
||||
resolved = attr_resolver('attr2', 'default', demo_obj, args, context, info)
|
||||
assert resolved == 'default'
|
||||
|
||||
|
||||
def test_dict_resolver():
|
||||
resolved = dict_resolver('attr', None, demo_dict, args, context, info)
|
||||
assert resolved == 'value'
|
||||
|
||||
|
||||
def test_dict_resolver_default_value():
|
||||
resolved = dict_resolver('attr2', 'default', demo_dict, args, context, info)
|
||||
assert resolved == 'default'
|
||||
|
||||
|
||||
def test_get_default_resolver_is_attr_resolver():
|
||||
assert get_default_resolver() == attr_resolver
|
||||
|
||||
|
||||
def test_set_default_resolver_workd():
|
||||
default_resolver = get_default_resolver()
|
||||
|
||||
set_default_resolver(dict_resolver)
|
||||
assert get_default_resolver() == dict_resolver
|
||||
|
||||
set_default_resolver(default_resolver)
|
|
@ -21,6 +21,7 @@ from .field import Field
|
|||
from .inputobjecttype import InputObjectType
|
||||
from .interface import Interface
|
||||
from .objecttype import ObjectType
|
||||
from .resolver import get_default_resolver
|
||||
from .scalars import ID, Boolean, Float, Int, Scalar, String
|
||||
from .structures import List, NonNull
|
||||
from .union import Union
|
||||
|
@ -205,9 +206,6 @@ class TypeMap(GraphQLTypeMap):
|
|||
return to_camel_case(name)
|
||||
return name
|
||||
|
||||
def default_resolver(self, attname, default_value, root, *_):
|
||||
return getattr(root, attname, default_value)
|
||||
|
||||
def construct_fields_for_type(self, map, type, is_input_type=False):
|
||||
fields = OrderedDict()
|
||||
for name, field in type._meta.fields.items():
|
||||
|
@ -267,7 +265,8 @@ class TypeMap(GraphQLTypeMap):
|
|||
if resolver:
|
||||
return get_unbound_function(resolver)
|
||||
|
||||
return partial(self.default_resolver, name, default_value)
|
||||
default_resolver = type._meta.default_resolver or get_default_resolver()
|
||||
return partial(default_resolver, name, default_value)
|
||||
|
||||
def get_field_type(self, map, type):
|
||||
if isinstance(type, List):
|
||||
|
|
Loading…
Reference in New Issue
Block a user