Added optional default_resolver to ObjectType.

This commit is contained in:
Syrus Akbary 2017-02-23 21:37:45 -08:00
parent 62e58bd953
commit 98825fa4bc
4 changed files with 73 additions and 6 deletions

View File

@ -19,12 +19,13 @@ class ObjectTypeMeta(AbstractTypeMeta):
if not is_base_type(bases, ObjectTypeMeta): if not is_base_type(bases, ObjectTypeMeta):
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
_meta = attrs.pop('_meta', None) attrs.pop('_meta', None)
options = _meta or Options( options = Options(
attrs.pop('Meta', None), attrs.pop('Meta', None),
name=name, name=name,
description=trim_docstring(attrs.get('__doc__')), description=trim_docstring(attrs.get('__doc__')),
interfaces=(), interfaces=(),
default_resolver=None,
local_fields=OrderedDict(), local_fields=OrderedDict(),
) )
options.base_fields = get_base_fields(bases, _as=Field) options.base_fields = get_base_fields(bases, _as=Field)

View File

@ -0,0 +1,19 @@
def attr_resolver(attname, default_value, root, args, context, info):
return getattr(root, attname, default_value)
def dict_resolver(attname, default_value, root, args, context, info):
return root.get(attname, default_value)
default_resolver = attr_resolver
def set_default_resolver(resolver):
global default_resolver
assert callable(resolver), 'Received non-callable resolver.'
default_resolver = resolver
def get_default_resolver():
return default_resolver

View File

@ -0,0 +1,48 @@
import pytest
from ..resolver import attr_resolver, dict_resolver, get_default_resolver, set_default_resolver
args = {}
context = None
info = None
demo_dict = {
'attr': 'value'
}
class demo_obj(object):
attr = 'value'
def test_attr_resolver():
resolved = attr_resolver('attr', None, demo_obj, args, context, info)
assert resolved == 'value'
def test_attr_resolver_default_value():
resolved = attr_resolver('attr2', 'default', demo_obj, args, context, info)
assert resolved == 'default'
def test_dict_resolver():
resolved = dict_resolver('attr', None, demo_dict, args, context, info)
assert resolved == 'value'
def test_dict_resolver_default_value():
resolved = dict_resolver('attr2', 'default', demo_dict, args, context, info)
assert resolved == 'default'
def test_get_default_resolver_is_attr_resolver():
assert get_default_resolver() == attr_resolver
def test_set_default_resolver_workd():
default_resolver = get_default_resolver()
set_default_resolver(dict_resolver)
assert get_default_resolver() == dict_resolver
set_default_resolver(default_resolver)

View File

@ -21,6 +21,7 @@ from .field import Field
from .inputobjecttype import InputObjectType from .inputobjecttype import InputObjectType
from .interface import Interface from .interface import Interface
from .objecttype import ObjectType from .objecttype import ObjectType
from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull from .structures import List, NonNull
from .union import Union from .union import Union
@ -205,9 +206,6 @@ class TypeMap(GraphQLTypeMap):
return to_camel_case(name) return to_camel_case(name)
return name return name
def default_resolver(self, attname, default_value, root, *_):
return getattr(root, attname, default_value)
def construct_fields_for_type(self, map, type, is_input_type=False): def construct_fields_for_type(self, map, type, is_input_type=False):
fields = OrderedDict() fields = OrderedDict()
for name, field in type._meta.fields.items(): for name, field in type._meta.fields.items():
@ -267,7 +265,8 @@ class TypeMap(GraphQLTypeMap):
if resolver: if resolver:
return get_unbound_function(resolver) return get_unbound_function(resolver)
return partial(self.default_resolver, name, default_value) default_resolver = type._meta.default_resolver or get_default_resolver()
return partial(default_resolver, name, default_value)
def get_field_type(self, map, type): def get_field_type(self, map, type):
if isinstance(type, List): if isinstance(type, List):