diff --git a/.gitignore b/.gitignore index b7554723..9f465556 100644 --- a/.gitignore +++ b/.gitignore @@ -75,6 +75,7 @@ target/ # PyCharm .idea +*.iml # Databases *.sqlite3 diff --git a/graphene/types/field.py b/graphene/types/field.py index 531c2f5c..6331b0ef 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -19,7 +19,8 @@ class Field(OrderedType): def __init__(self, type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, - required=False, _creation_counter=None, **extra_args): + required=False, _creation_counter=None, default_value=None, + **extra_args): super(Field, self).__init__(_creation_counter=_creation_counter) assert not args or isinstance(args, Mapping), ( 'Arguments in a field have to be a mapping, received "{}".' @@ -27,6 +28,9 @@ class Field(OrderedType): assert not (source and resolver), ( 'A Field cannot have a source and a resolver in at the same time.' ) + assert not callable(default_value), ( + 'The default value can not be a function but received "{}".' + ).format(type(default_value)) if required: type = NonNull(type) @@ -49,6 +53,7 @@ class Field(OrderedType): self.resolver = resolver self.deprecation_reason = deprecation_reason self.description = description + self.default_value = default_value @property def type(self): diff --git a/graphene/types/tests/test_field.py b/graphene/types/tests/test_field.py index 5883e588..7ca557ba 100644 --- a/graphene/types/tests/test_field.py +++ b/graphene/types/tests/test_field.py @@ -16,19 +16,22 @@ def test_field_basic(): resolver = lambda: None deprecation_reason = 'Deprecated now' description = 'My Field' + my_default='something' field = Field( MyType, name='name', args=args, resolver=resolver, description=description, - deprecation_reason=deprecation_reason + deprecation_reason=deprecation_reason, + default_value=my_default, ) assert field.name == 'name' assert field.args == args assert field.resolver == resolver assert field.deprecation_reason == deprecation_reason assert field.description == description + assert field.default_value == my_default def test_field_required(): @@ -38,6 +41,15 @@ def test_field_required(): assert field.type.of_type == MyType +def test_field_default_value_not_callable(): + MyType = object() + try: + Field(MyType, default_value=lambda: True) + except AssertionError as e: + # substring comparison for py 2/3 compatibility + assert 'The default value can not be a function but received' in str(e) + + def test_field_source(): MyType = object() field = Field(MyType, source='value') diff --git a/graphene/types/tests/test_query.py b/graphene/types/tests/test_query.py index 3059776a..1f09b925 100644 --- a/graphene/types/tests/test_query.py +++ b/graphene/types/tests/test_query.py @@ -1,8 +1,9 @@ import json from functools import partial -from graphql import Source, execute, parse +from graphql import Source, execute, parse, GraphQLError +from ..field import Field from ..inputfield import InputField from ..inputobjecttype import InputObjectType from ..objecttype import ObjectType @@ -22,6 +23,49 @@ def test_query(): assert executed.data == {'hello': 'World'} +def test_query_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value=MyType(field='something else!')) + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert not executed.errors + assert executed.data == {'hello': {'field': 'something else!'}} + + +def test_query_wrong_default_value(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value='hello') + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert len(executed.errors) == 1 + assert executed.errors[0].message == GraphQLError('Expected value of type "MyType" but got: str.').message + assert executed.data == {'hello': None} + + +def test_query_default_value_ignored_by_resolver(): + class MyType(ObjectType): + field = String() + + class Query(ObjectType): + hello = Field(MyType, default_value='hello', resolver=lambda *_: MyType(field='no default.')) + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello { field } }') + assert not executed.errors + assert executed.data == {'hello': {'field': 'no default.'}} + + def test_query_resolve_function(): class Query(ObjectType): hello = String() diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index af193a9b..ed036d6c 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -190,8 +190,8 @@ class TypeMap(GraphQLTypeMap): return to_camel_case(name) return name - def default_resolver(self, attname, root, *_): - return getattr(root, attname, None) + def default_resolver(self, attname, default_value, root, *_): + return getattr(root, attname, default_value) def construct_fields_for_type(self, map, type, is_input_type=False): fields = OrderedDict() @@ -224,7 +224,7 @@ class TypeMap(GraphQLTypeMap): _field = GraphQLField( field_type, args=args, - resolver=field.get_resolver(self.get_resolver_for_type(type, name)), + resolver=field.get_resolver(self.get_resolver_for_type(type, name, field.default_value)), deprecation_reason=field.deprecation_reason, description=field.description ) @@ -232,7 +232,7 @@ class TypeMap(GraphQLTypeMap): fields[field_name] = _field return fields - def get_resolver_for_type(self, type, name): + def get_resolver_for_type(self, type, name, default_value): if not issubclass(type, ObjectType): return resolver = getattr(type, 'resolve_{}'.format(name), None) @@ -253,7 +253,7 @@ class TypeMap(GraphQLTypeMap): return resolver.__func__ return resolver - return partial(self.default_resolver, name) + return partial(self.default_resolver, name, default_value) def get_field_type(self, map, type): if isinstance(type, List):