diff --git a/graphene/types/field.py b/graphene/types/field.py index b1122696..88a60f25 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -2,6 +2,8 @@ import inspect from collections import Mapping, OrderedDict from functools import partial +from promise import Promise + from ..utils.orderedtype import OrderedType from .argument import to_arguments from .structures import NonNull @@ -16,9 +18,10 @@ def source_resolver(source, root, args, context, info): class Field(OrderedType): - def __init__(self, type, args=None, resolver=None, source=None, + def __init__(self, object_type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, - required=False, _creation_counter=None, **extra_args): + required=False, _creation_counter=None, default_value=None, + **extra_args): super(Field, self).__init__(_creation_counter=_creation_counter) assert not args or isinstance(args, Mapping), ( 'Arguments in a field have to be a mapping, received "{}".' @@ -26,18 +29,22 @@ class Field(OrderedType): assert not (source and resolver), ( 'A Field cannot have a source and a resolver in at the same time.' ) + assert not callable(default_value), ( + 'The default value can not be a function but received "{}".' + ).format(type(default_value)) - if required: - type = NonNull(type) + if required or default_value is not None: + object_type = NonNull(object_type) self.name = name - self._type = type + self._type = object_type self.args = to_arguments(args or OrderedDict(), extra_args) if source: resolver = partial(source_resolver, source) self.resolver = resolver self.deprecation_reason = deprecation_reason self.description = description + self.default_value = default_value @property def type(self): @@ -46,4 +53,14 @@ class Field(OrderedType): return self._type def get_resolver(self, parent_resolver): - return self.resolver or parent_resolver + resolver = self.resolver or parent_resolver + + def default_resolve(result): + return self.default_value if result is None else result + + def defualt_resolver(self, *args, **kwargs): + return Promise.resolve( + resolver(*args, **kwargs) + ).then(default_resolve) + + return defualt_resolver if self.default_value is not None else resolver diff --git a/graphene/types/tests/test_field.py b/graphene/types/tests/test_field.py index 5883e588..7ca557ba 100644 --- a/graphene/types/tests/test_field.py +++ b/graphene/types/tests/test_field.py @@ -16,19 +16,22 @@ def test_field_basic(): resolver = lambda: None deprecation_reason = 'Deprecated now' description = 'My Field' + my_default='something' field = Field( MyType, name='name', args=args, resolver=resolver, description=description, - deprecation_reason=deprecation_reason + deprecation_reason=deprecation_reason, + default_value=my_default, ) assert field.name == 'name' assert field.args == args assert field.resolver == resolver assert field.deprecation_reason == deprecation_reason assert field.description == description + assert field.default_value == my_default def test_field_required(): @@ -38,6 +41,15 @@ def test_field_required(): assert field.type.of_type == MyType +def test_field_default_value_not_callable(): + MyType = object() + try: + Field(MyType, default_value=lambda: True) + except AssertionError as e: + # substring comparison for py 2/3 compatibility + assert 'The default value can not be a function but received' in str(e) + + def test_field_source(): MyType = object() field = Field(MyType, source='value') diff --git a/graphene/types/tests/test_query.py b/graphene/types/tests/test_query.py index dd6de01c..f1a5e417 100644 --- a/graphene/types/tests/test_query.py +++ b/graphene/types/tests/test_query.py @@ -1,9 +1,10 @@ import json from functools import partial -from graphql import execute, Source, parse +from graphql import execute, Source, parse, GraphQLError from ..objecttype import ObjectType +from ..field import Field from ..inputfield import InputField from ..inputobjecttype import InputObjectType from ..scalars import String, Int @@ -22,6 +23,35 @@ def test_query(): assert executed.data == {'hello': 'World'} +def test_query_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value=MyType(field='something else!')) + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert not executed.errors + assert executed.data == {'hello': {'field': 'something else!'}} + + +def test_query_wrong_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value='hello') + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert len(executed.errors) == 1 + assert executed.errors[0].message == GraphQLError('Expected value of type "MyType" but got: str.').message + assert executed.data is None + + def test_query_resolve_function(): class Query(ObjectType): hello = String()