diff --git a/graphene/types/resolver.py b/graphene/types/resolver.py index 888aba8a..6a8ea02b 100644 --- a/graphene/types/resolver.py +++ b/graphene/types/resolver.py @@ -6,7 +6,14 @@ def dict_resolver(attname, default_value, root, info, **args): return root.get(attname, default_value) -default_resolver = attr_resolver +def dict_or_attr_resolver(attname, default_value, root, info, **args): + resolver = attr_resolver + if isinstance(root, dict): + resolver = dict_resolver + return resolver(attname, default_value, root, info, **args) + + +default_resolver = dict_or_attr_resolver def set_default_resolver(resolver): diff --git a/graphene/types/tests/test_resolver.py b/graphene/types/tests/test_resolver.py index 2a15028d..a03cf187 100644 --- a/graphene/types/tests/test_resolver.py +++ b/graphene/types/tests/test_resolver.py @@ -1,6 +1,7 @@ from ..resolver import ( attr_resolver, dict_resolver, + dict_or_attr_resolver, get_default_resolver, set_default_resolver, ) @@ -36,8 +37,16 @@ def test_dict_resolver_default_value(): assert resolved == "default" +def test_dict_or_attr_resolver(): + resolved = dict_or_attr_resolver("attr", None, demo_dict, info, **args) + assert resolved == "value" + + resolved = dict_or_attr_resolver("attr", None, demo_obj, info, **args) + assert resolved == "value" + + def test_get_default_resolver_is_attr_resolver(): - assert get_default_resolver() == attr_resolver + assert get_default_resolver() == dict_or_attr_resolver def test_set_default_resolver_workd():