diff --git a/graphene/types/field.py b/graphene/types/field.py index b1122696..d67707ac 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -2,6 +2,8 @@ import inspect from collections import Mapping, OrderedDict from functools import partial +from promise import Promise + from ..utils.orderedtype import OrderedType from .argument import to_arguments from .structures import NonNull @@ -18,7 +20,8 @@ class Field(OrderedType): def __init__(self, type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, - required=False, _creation_counter=None, **extra_args): + required=False, _creation_counter=None, default_value=None, + **extra_args): super(Field, self).__init__(_creation_counter=_creation_counter) assert not args or isinstance(args, Mapping), ( 'Arguments in a field have to be a mapping, received "{}".' @@ -27,7 +30,7 @@ class Field(OrderedType): 'A Field cannot have a source and a resolver in at the same time.' ) - if required: + if required or default_value is not None: type = NonNull(type) self.name = name @@ -38,6 +41,7 @@ class Field(OrderedType): self.resolver = resolver self.deprecation_reason = deprecation_reason self.description = description + self.default_value = default_value @property def type(self): @@ -46,4 +50,14 @@ class Field(OrderedType): return self._type def get_resolver(self, parent_resolver): - return self.resolver or parent_resolver + resolver = self.resolver or parent_resolver + + def default_resolve(result): + return self.default_value if result is None else result + + def defualt_resolver(self, *args, **kwargs): + return Promise.resolve( + resolver(*args, **kwargs) + ).then(default_resolve) + + return defualt_resolver if self.default_value is not None else resolver diff --git a/graphene/types/tests/test_field.py b/graphene/types/tests/test_field.py index 5883e588..ae0f08ea 100644 --- a/graphene/types/tests/test_field.py +++ b/graphene/types/tests/test_field.py @@ -16,19 +16,22 @@ def test_field_basic(): resolver = lambda: None deprecation_reason = 'Deprecated now' description = 'My Field' + my_default='something' field = Field( MyType, name='name', args=args, resolver=resolver, description=description, - deprecation_reason=deprecation_reason + deprecation_reason=deprecation_reason, + default_value=my_default, ) assert field.name == 'name' assert field.args == args assert field.resolver == resolver assert field.deprecation_reason == deprecation_reason assert field.description == description + assert field.default_value == my_default def test_field_required(): diff --git a/graphene/types/tests/test_query.py b/graphene/types/tests/test_query.py index dd6de01c..f1a5e417 100644 --- a/graphene/types/tests/test_query.py +++ b/graphene/types/tests/test_query.py @@ -1,9 +1,10 @@ import json from functools import partial -from graphql import execute, Source, parse +from graphql import execute, Source, parse, GraphQLError from ..objecttype import ObjectType +from ..field import Field from ..inputfield import InputField from ..inputobjecttype import InputObjectType from ..scalars import String, Int @@ -22,6 +23,35 @@ def test_query(): assert executed.data == {'hello': 'World'} +def test_query_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value=MyType(field='something else!')) + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert not executed.errors + assert executed.data == {'hello': {'field': 'something else!'}} + + +def test_query_wrong_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value='hello') + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert len(executed.errors) == 1 + assert executed.errors[0].message == GraphQLError('Expected value of type "MyType" but got: str.').message + assert executed.data is None + + def test_query_resolve_function(): class Query(ObjectType): hello = String()