Added default value for default resolver.

This commit is contained in:
Markus Padourek 2016-09-21 09:34:29 +01:00
parent 316569b019
commit d9b8f5941d
5 changed files with 73 additions and 11 deletions

1
.gitignore vendored
View File

@ -75,6 +75,7 @@ target/
# PyCharm # PyCharm
.idea .idea
*.iml
# Databases # Databases
*.sqlite3 *.sqlite3

View File

@ -17,9 +17,10 @@ def source_resolver(source, root, args, context, info):
class Field(OrderedType): class Field(OrderedType):
def __init__(self, type, args=None, resolver=None, source=None, def __init__(self, gql_type, args=None, resolver=None, source=None,
deprecation_reason=None, name=None, description=None, deprecation_reason=None, name=None, description=None,
required=False, _creation_counter=None, **extra_args): required=False, _creation_counter=None, default_value=None,
**extra_args):
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(_creation_counter=_creation_counter)
assert not args or isinstance(args, Mapping), ( assert not args or isinstance(args, Mapping), (
'Arguments in a field have to be a mapping, received "{}".' 'Arguments in a field have to be a mapping, received "{}".'
@ -27,9 +28,12 @@ class Field(OrderedType):
assert not (source and resolver), ( assert not (source and resolver), (
'A Field cannot have a source and a resolver in at the same time.' 'A Field cannot have a source and a resolver in at the same time.'
) )
assert not callable(default_value), (
'The default value can not be a function but received "{}".'
).format(type(default_value))
if required: if required:
type = NonNull(type) gql_type = NonNull(gql_type)
# Check if name is actually an argument of the field # Check if name is actually an argument of the field
if isinstance(name, (Argument, UnmountedType)): if isinstance(name, (Argument, UnmountedType)):
@ -42,13 +46,14 @@ class Field(OrderedType):
source = None source = None
self.name = name self.name = name
self._type = type self._type = gql_type
self.args = to_arguments(args or OrderedDict(), extra_args) self.args = to_arguments(args or OrderedDict(), extra_args)
if source: if source:
resolver = partial(source_resolver, source) resolver = partial(source_resolver, source)
self.resolver = resolver self.resolver = resolver
self.deprecation_reason = deprecation_reason self.deprecation_reason = deprecation_reason
self.description = description self.description = description
self.default_value = default_value
@property @property
def type(self): def type(self):

View File

@ -16,19 +16,22 @@ def test_field_basic():
resolver = lambda: None resolver = lambda: None
deprecation_reason = 'Deprecated now' deprecation_reason = 'Deprecated now'
description = 'My Field' description = 'My Field'
my_default='something'
field = Field( field = Field(
MyType, MyType,
name='name', name='name',
args=args, args=args,
resolver=resolver, resolver=resolver,
description=description, description=description,
deprecation_reason=deprecation_reason deprecation_reason=deprecation_reason,
default_value=my_default,
) )
assert field.name == 'name' assert field.name == 'name'
assert field.args == args assert field.args == args
assert field.resolver == resolver assert field.resolver == resolver
assert field.deprecation_reason == deprecation_reason assert field.deprecation_reason == deprecation_reason
assert field.description == description assert field.description == description
assert field.default_value == my_default
def test_field_required(): def test_field_required():
@ -38,6 +41,15 @@ def test_field_required():
assert field.type.of_type == MyType assert field.type.of_type == MyType
def test_field_default_value_not_callable():
MyType = object()
try:
Field(MyType, default_value=lambda: True)
except AssertionError as e:
# substring comparison for py 2/3 compatibility
assert 'The default value can not be a function but received' in str(e)
def test_field_source(): def test_field_source():
MyType = object() MyType = object()
field = Field(MyType, source='value') field = Field(MyType, source='value')

View File

@ -1,8 +1,9 @@
import json import json
from functools import partial from functools import partial
from graphql import Source, execute, parse from graphql import Source, execute, parse, GraphQLError
from ..field import Field
from ..inputfield import InputField from ..inputfield import InputField
from ..inputobjecttype import InputObjectType from ..inputobjecttype import InputObjectType
from ..objecttype import ObjectType from ..objecttype import ObjectType
@ -22,6 +23,49 @@ def test_query():
assert executed.data == {'hello': 'World'} assert executed.data == {'hello': 'World'}
def test_query_default_value():
class MyType(ObjectType):
field = String()
class Query(ObjectType):
hello = Field(MyType, default_value=MyType(field='something else!'))
hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }')
assert not executed.errors
assert executed.data == {'hello': {'field': 'something else!'}}
def test_query_wrong_default_value():
class MyType(ObjectType):
field = String()
class Query(ObjectType):
hello = Field(MyType, default_value='hello')
hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }')
assert len(executed.errors) == 1
assert executed.errors[0].message == GraphQLError('Expected value of type "MyType" but got: str.').message
assert executed.data == {'hello': None}
def test_query_default_value_ignored_by_resolver():
class MyType(ObjectType):
field = String()
class Query(ObjectType):
hello = Field(MyType, default_value='hello', resolver=lambda *_: MyType(field='no default.'))
hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }')
assert not executed.errors
assert executed.data == {'hello': {'field': 'no default.'}}
def test_query_resolve_function(): def test_query_resolve_function():
class Query(ObjectType): class Query(ObjectType):
hello = String() hello = String()

View File

@ -190,8 +190,8 @@ class TypeMap(GraphQLTypeMap):
return to_camel_case(name) return to_camel_case(name)
return name return name
def default_resolver(self, attname, root, *_): def default_resolver(self, attname, default_value, root, *_):
return getattr(root, attname, None) return getattr(root, attname, default_value)
def construct_fields_for_type(self, map, type, is_input_type=False): def construct_fields_for_type(self, map, type, is_input_type=False):
fields = OrderedDict() fields = OrderedDict()
@ -224,7 +224,7 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=field.get_resolver(self.get_resolver_for_type(type, name)), resolver=field.get_resolver(self.get_resolver_for_type(type, name, field.default_value)),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description description=field.description
) )
@ -232,7 +232,7 @@ class TypeMap(GraphQLTypeMap):
fields[field_name] = _field fields[field_name] = _field
return fields return fields
def get_resolver_for_type(self, type, name): def get_resolver_for_type(self, type, name, default_value):
if not issubclass(type, ObjectType): if not issubclass(type, ObjectType):
return return
resolver = getattr(type, 'resolve_{}'.format(name), None) resolver = getattr(type, 'resolve_{}'.format(name), None)
@ -253,7 +253,7 @@ class TypeMap(GraphQLTypeMap):
return resolver.__func__ return resolver.__func__
return resolver return resolver
return partial(self.default_resolver, name) return partial(self.default_resolver, name, default_value)
def get_field_type(self, map, type): def get_field_type(self, map, type):
if isinstance(type, List): if isinstance(type, List):