Make resolvers simple again

This commit is contained in:
Syrus Akbary 2017-07-23 23:10:15 -07:00
parent 800fbdf820
commit 9769612a44
23 changed files with 93 additions and 99 deletions

View File

@ -32,7 +32,7 @@ Also, Graphene is fully compatible with the GraphQL spec, working seamlessly wit
For instaling graphene, just run this command in your shell For instaling graphene, just run this command in your shell
```bash ```bash
pip install "graphene>=2.0" pip install "graphene>=2.0.dev"
``` ```
## 2.0 Upgrade Guide ## 2.0 Upgrade Guide
@ -48,7 +48,7 @@ Here is one example for you to get started:
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
hello = graphene.String(description='A typical hello world') hello = graphene.String(description='A typical hello world')
def resolve_hello(self, args, context, info): def resolve_hello(self):
return 'World' return 'World'
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)

View File

@ -65,7 +65,7 @@ Here is one example for you to get started:
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
hello = graphene.String(description='A typical hello world') hello = graphene.String(description='A typical hello world')
def resolve_hello(self, args, context, info): def resolve_hello(self):
return 'World' return 'World'
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)

View File

@ -49,7 +49,7 @@ class Pet(CommonFields, Interface):
### resolve\_only\_args ### resolve\_only\_args
`resolve_only_args` is now deprecated in favor of type annotations (using the polyfill `@graphene.annotate` in Python 2). `resolve_only_args` is now deprecated in favor of type annotations (using the polyfill `@graphene.annotate` in Python 2 in case is necessary for accessing `context` or `info`).
Before: Before:
@ -68,8 +68,7 @@ With 2.0:
class User(ObjectType): class User(ObjectType):
name = String() name = String()
# Decorate the resolver with @annotate in Python 2 def resolve_name(self):
def resolve_name(self) -> str:
return self.name return self.name
``` ```
@ -129,7 +128,7 @@ def is_user_id(id):
return id.startswith('userid_') return id.startswith('userid_')
class Query(ObjectType): class Query(ObjectType):
user = graphene.Field(User, id=UserInput()) user = graphene.Field(User, input=UserInput())
@resolve_only_args @resolve_only_args
def resolve_user(self, input): def resolve_user(self, input):
@ -149,10 +148,9 @@ class UserInput(InputObjectType):
return self.id.startswith('userid_') return self.id.startswith('userid_')
class Query(ObjectType): class Query(ObjectType):
user = graphene.Field(User, id=UserInput()) user = graphene.Field(User, input=UserInput())
# Decorate the resolver with @annotate(input=UserInput) in Python 2 def resolve_user(self, input):
def resolve_user(self, input: UserInput) -> User:
if input.is_user_id: if input.is_user_id:
return get_user(input.id) return get_user(input.id)

View File

@ -9,7 +9,8 @@ class User(graphene.ObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
me = graphene.Field(User) me = graphene.Field(User)
def resolve_me(self, args, context, info): @graphene.annotate(context=graphene.Context)
def resolve_me(self, context):
return context['user'] return context['user']

View File

@ -11,7 +11,7 @@ class Query(graphene.ObjectType):
patron = graphene.Field(Patron) patron = graphene.Field(Patron)
def resolve_patron(self, args, context, info): def resolve_patron(self):
return Patron(id=1, name='Syrus', age=27) return Patron(id=1, name='Syrus', age=27)

View File

@ -1,5 +1,4 @@
import graphene import graphene
from graphene import annotate
from .data import get_character, get_droid, get_hero, get_human from .data import get_character, get_droid, get_hero, get_human
@ -16,7 +15,7 @@ class Character(graphene.Interface):
friends = graphene.List(lambda: Character) friends = graphene.List(lambda: Character)
appears_in = graphene.List(Episode) appears_in = graphene.List(Episode)
def resolve_friends(self, args, *_): def resolve_friends(self):
# The character friends is a list of strings # The character friends is a list of strings
return [get_character(f) for f in self.friends] return [get_character(f) for f in self.friends]
@ -46,15 +45,12 @@ class Query(graphene.ObjectType):
id=graphene.String() id=graphene.String()
) )
@annotate(episode=Episode)
def resolve_hero(self, episode=None): def resolve_hero(self, episode=None):
return get_hero(episode) return get_hero(episode)
@annotate(id=str)
def resolve_human(self, id): def resolve_human(self, id):
return get_human(id) return get_human(id)
@annotate(id=str)
def resolve_droid(self, id): def resolve_droid(self, id):
return get_droid(id) return get_droid(id)

View File

@ -32,7 +32,6 @@ class Faction(graphene.ObjectType):
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.') ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.')
@annotate
def resolve_ships(self, **args): def resolve_ships(self, **args):
# Transform the instance ship_ids into real instances # Transform the instance ship_ids into real instances
return [get_ship(ship_id) for ship_id in self.ships] return [get_ship(ship_id) for ship_id in self.ships]
@ -65,11 +64,9 @@ class Query(graphene.ObjectType):
empire = graphene.Field(Faction) empire = graphene.Field(Faction)
node = relay.Node.Field() node = relay.Node.Field()
@annotate
def resolve_rebels(self): def resolve_rebels(self):
return get_rebels() return get_rebels()
@annotate
def resolve_empire(self): def resolve_empire(self):
return get_empire() return get_empire()

View File

@ -3,8 +3,9 @@ from collections import OrderedDict
from promise import Promise, is_thenable from promise import Promise, is_thenable
from ..types import Field, InputObjectType, String from ..types import Field, InputObjectType, String, Context, ResolveInfo
from ..types.mutation import Mutation from ..types.mutation import Mutation
from ..utils.annotate import annotate
class ClientIDMutation(Mutation): class ClientIDMutation(Mutation):
@ -49,9 +50,8 @@ class ClientIDMutation(Mutation):
) )
@classmethod @classmethod
def mutate(cls, root, args, context, info): @annotate(context=Context, info=ResolveInfo)
input = args.get('input') def mutate(cls, root, input, context, info):
def on_resolve(payload): def on_resolve(payload):
try: try:
payload.client_mutation_id = input.get('clientMutationId') payload.client_mutation_id = input.get('clientMutationId')

View File

@ -31,13 +31,13 @@ class Query(ObjectType):
node = Node.Field() node = Node.Field()
def resolve_letters(self, args, context, info): def resolve_letters(self, **args):
return list(letters.values()) return list(letters.values())
def resolve_promise_letters(self, args, context, info): def resolve_promise_letters(self, **args):
return Promise.resolve(list(letters.values())) return Promise.resolve(list(letters.values()))
def resolve_connection_letters(self, args, context, info): def resolve_connection_letters(self, **args):
return LetterConnection( return LetterConnection(
page_info=PageInfo( page_info=PageInfo(
has_next_page=True, has_next_page=True,

View File

@ -29,7 +29,6 @@ class CreatePost(graphene.Mutation):
result = graphene.Field(CreatePostResult) result = graphene.Field(CreatePostResult)
@resolve_only_args
def mutate(self, text): def mutate(self, text):
result = Success(yeah='yeah') result = Success(yeah='yeah')

View File

@ -6,8 +6,8 @@ import graphene
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
some_field = graphene.String(from_=graphene.String(name="from")) some_field = graphene.String(from_=graphene.String(name="from"))
def resolve_some_field(_, args, context, infos): def resolve_some_field(self, from_=None):
return args.get("from_") return from_
def test_issue(): def test_issue():

View File

@ -7,7 +7,6 @@ from .mountedtype import MountedType
from .structures import NonNull from .structures import NonNull
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
from .utils import get_type from .utils import get_type
from ..utils.auto_resolver import auto_resolver
base_type = type base_type = type
@ -64,4 +63,4 @@ class Field(MountedType):
return get_type(self._type) return get_type(self._type)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return auto_resolver(self.resolver or parent_resolver) return self.resolver or parent_resolver

View File

@ -5,6 +5,7 @@ from ..utils.props import props
from .field import Field from .field import Field
from .objecttype import ObjectType, ObjectTypeOptions from .objecttype import ObjectType, ObjectTypeOptions
from .utils import yank_fields_from_attrs from .utils import yank_fields_from_attrs
from ..utils.auto_resolver import auto_resolver
class MutationOptions(ObjectTypeOptions): class MutationOptions(ObjectTypeOptions):
@ -59,7 +60,7 @@ class Mutation(ObjectType):
_meta.fields = fields _meta.fields = fields
_meta.output = output _meta.output = output
_meta.resolver = resolver _meta.resolver = auto_resolver(resolver)
_meta.arguments = arguments _meta.arguments = arguments
super(Mutation, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(Mutation, cls).__init_subclass_with_meta__(_meta=_meta, **options)

View File

@ -11,12 +11,11 @@ class Query(ObjectType):
datetime = DateTime(_in=DateTime(name='in')) datetime = DateTime(_in=DateTime(name='in'))
time = Time(_at=Time(name='at')) time = Time(_at=Time(name='at'))
def resolve_datetime(self, args, context, info): def resolve_datetime(self, _in=None):
_in = args.get('_in')
return _in return _in
def resolve_time(self, args, context, info): def resolve_time(self, _at=None):
return args.get('_at') return _at
schema = Schema(query=Query) schema = Schema(query=Query)

View File

@ -6,8 +6,7 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
generic = GenericScalar(input=GenericScalar()) generic = GenericScalar(input=GenericScalar())
def resolve_generic(self, args, context, info): def resolve_generic(self, input=None):
input = args.get('input')
return input return input

View File

@ -7,8 +7,7 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
json = JSONString(input=JSONString()) json = JSONString(input=JSONString())
def resolve_json(self, args, context, info): def resolve_json(self, input):
input = args.get('input')
return input return input
schema = Schema(query=Query) schema = Schema(query=Query)

View File

@ -12,14 +12,14 @@ def test_generate_mutation_no_args():
class MyMutation(Mutation): class MyMutation(Mutation):
'''Documentation''' '''Documentation'''
@classmethod def mutate(self, **args):
def mutate(cls, *args, **kwargs): return args
pass
assert issubclass(MyMutation, ObjectType) assert issubclass(MyMutation, ObjectType)
assert MyMutation._meta.name == "MyMutation" assert MyMutation._meta.name == "MyMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
assert MyMutation.Field().resolver == MyMutation.mutate resolved = MyMutation.Field().resolver(None, {'name': 'Peter'}, None, None)
assert resolved == {'name': 'Peter'}
def test_generate_mutation_with_meta(): def test_generate_mutation_with_meta():
@ -29,13 +29,13 @@ def test_generate_mutation_with_meta():
name = 'MyOtherMutation' name = 'MyOtherMutation'
description = 'Documentation' description = 'Documentation'
@classmethod def mutate(self, **args):
def mutate(cls, *args, **kwargs): return args
pass
assert MyMutation._meta.name == "MyOtherMutation" assert MyMutation._meta.name == "MyOtherMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
assert MyMutation.Field().resolver == MyMutation.mutate resolved = MyMutation.Field().resolver(None, {'name': 'Peter'}, None, None)
assert resolved == {'name': 'Peter'}
def test_mutation_raises_exception_if_no_mutate(): def test_mutation_raises_exception_if_no_mutate():
@ -59,15 +59,15 @@ def test_mutation_custom_output_type():
Output = User Output = User
@classmethod def mutate(self, name):
def mutate(cls, args, context, info):
name = args.get('name')
return User(name=name) return User(name=name)
field = CreateUser.Field() field = CreateUser.Field()
assert field.type == User assert field.type == User
assert field.args == {'name': Argument(String)} assert field.args == {'name': Argument(String)}
assert field.resolver == CreateUser.mutate resolved = field.resolver(None, {'name': 'Peter'}, None, None)
assert isinstance(resolved, User)
assert resolved.name == 'Peter'
def test_mutation_execution(): def test_mutation_execution():
@ -81,9 +81,7 @@ def test_mutation_execution():
name = String() name = String()
dynamic = Dynamic(lambda: String()) dynamic = Dynamic(lambda: String())
def mutate(self, args, context, info): def mutate(self, name, dynamic):
name = args.get('name')
dynamic = args.get('dynamic')
return CreateUser(name=name, dynamic=dynamic) return CreateUser(name=name, dynamic=dynamic)
class Query(ObjectType): class Query(ObjectType):

View File

@ -57,7 +57,7 @@ def test_query_union():
class Query(ObjectType): class Query(ObjectType):
unions = List(MyUnion) unions = List(MyUnion)
def resolve_unions(self, args, context, info): def resolve_unions(self):
return [one_object(), two_object()] return [one_object(), two_object()]
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -108,7 +108,7 @@ def test_query_interface():
class Query(ObjectType): class Query(ObjectType):
interfaces = List(MyInterface) interfaces = List(MyInterface)
def resolve_interfaces(self, args, context, info): def resolve_interfaces(self):
return [one_object(), two_object()] return [one_object(), two_object()]
hello_schema = Schema(Query, types=[One, Two]) hello_schema = Schema(Query, types=[One, Two])
@ -188,7 +188,7 @@ def test_query_resolve_function():
class Query(ObjectType): class Query(ObjectType):
hello = String() hello = String()
def resolve_hello(self, args, context, info): def resolve_hello(self):
return 'World' return 'World'
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -202,7 +202,7 @@ def test_query_arguments():
class Query(ObjectType): class Query(ObjectType):
test = String(a_str=String(), a_int=Int()) test = String(a_str=String(), a_int=Int())
def resolve_test(self, args, context, info): def resolve_test(self, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(',', ':'))
test_schema = Schema(Query) test_schema = Schema(Query)
@ -231,7 +231,7 @@ def test_query_input_field():
class Query(ObjectType): class Query(ObjectType):
test = String(a_input=Input()) test = String(a_input=Input())
def resolve_test(self, args, context, info): def resolve_test(self, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(',', ':'))
test_schema = Schema(Query) test_schema = Schema(Query)
@ -254,10 +254,10 @@ def test_query_middlewares():
hello = String() hello = String()
other = String() other = String()
def resolve_hello(self, args, context, info): def resolve_hello(self):
return 'World' return 'World'
def resolve_other(self, args, context, info): def resolve_other(self):
return 'other' return 'other'
def reversed_middleware(next, *args, **kwargs): def reversed_middleware(next, *args, **kwargs):
@ -280,14 +280,14 @@ def test_objecttype_on_instances():
class ShipType(ObjectType): class ShipType(ObjectType):
name = String(description="Ship name", required=True) name = String(description="Ship name", required=True)
def resolve_name(self, context, args, info): def resolve_name(self):
# Here self will be the Ship instance returned in resolve_ship # Here self will be the Ship instance returned in resolve_ship
return self.name return self.name
class Query(ObjectType): class Query(ObjectType):
ship = Field(ShipType) ship = Field(ShipType)
def resolve_ship(self, context, args, info): def resolve_ship(self):
return Ship(name='xwing') return Ship(name='xwing')
schema = Schema(query=Query) schema = Schema(query=Query)
@ -302,7 +302,7 @@ def test_big_list_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_ints = List(Int) all_ints = List(Int)
def resolve_all_ints(self, args, context, info): def resolve_all_ints(self):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -319,7 +319,7 @@ def test_big_list_query_compiled_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_ints = List(Int) all_ints = List(Int)
def resolve_all_ints(self, args, context, info): def resolve_all_ints(self):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -341,7 +341,7 @@ def test_big_list_of_containers_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_containers = List(Container) all_containers = List(Container)
def resolve_all_containers(self, args, context, info): def resolve_all_containers(self):
return big_container_list return big_container_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -364,7 +364,7 @@ def test_big_list_of_containers_multiple_fields_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_containers = List(Container) all_containers = List(Container)
def resolve_all_containers(self, args, context, info): def resolve_all_containers(self):
return big_container_list return big_container_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -382,16 +382,16 @@ def test_big_list_of_containers_multiple_fields_custom_resolvers_query_benchmark
z = Int() z = Int()
o = Int() o = Int()
def resolve_x(self, args, context, info): def resolve_x(self):
return self.x return self.x
def resolve_y(self, args, context, info): def resolve_y(self):
return self.y return self.y
def resolve_z(self, args, context, info): def resolve_z(self):
return self.z return self.z
def resolve_o(self, args, context, info): def resolve_o(self):
return self.o return self.o
big_container_list = [Container(x=x, y=x, z=x, o=x) for x in range(1000)] big_container_list = [Container(x=x, y=x, z=x, o=x) for x in range(1000)]
@ -399,7 +399,7 @@ def test_big_list_of_containers_multiple_fields_custom_resolvers_query_benchmark
class Query(ObjectType): class Query(ObjectType):
all_containers = List(Container) all_containers = List(Container)
def resolve_all_containers(self, args, context, info): def resolve_all_containers(self):
return big_container_list return big_container_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -420,7 +420,6 @@ def test_query_annotated_resolvers():
context = String() context = String()
info = String() info = String()
@annotate(_trigger_warning=False)
def resolve_annotated(self, id): def resolve_annotated(self, id):
return "{}-{}".format(self, id) return "{}-{}".format(self, id)

View File

@ -49,8 +49,8 @@ def test_objecttype():
foo = String(bar=String(description='Argument description', default_value='x'), description='Field description') foo = String(bar=String(description='Argument description', default_value='x'), description='Field description')
bar = String(name='gizmo') bar = String(name='gizmo')
def resolve_foo(self, args, info): def resolve_foo(self, bar):
return args.get('bar') return bar
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
assert 'MyObjectType' in typemap assert 'MyObjectType' in typemap
@ -65,7 +65,7 @@ def test_objecttype():
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.description == 'Field description' assert foo_field.description == 'Field description'
f = MyObjectType.resolve_foo f = MyObjectType.resolve_foo
assert foo_field.resolver == getattr(f, '__func__', f) # assert foo_field.resolver == getattr(f, '__func__', f)
assert foo_field.args == { assert foo_field.args == {
'bar': GraphQLArgument(GraphQLString, description='Argument description', default_value='x', out_name='bar') 'bar': GraphQLArgument(GraphQLString, description='Argument description', default_value='x', out_name='bar')
} }

View File

@ -11,6 +11,7 @@ from graphql.type.typemap import GraphQLTypeMap
from ..utils.get_unbound_function import get_unbound_function from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
from ..utils.auto_resolver import auto_resolver, final_resolver
from .definitions import (GrapheneEnumType, GrapheneGraphQLType, from .definitions import (GrapheneEnumType, GrapheneGraphQLType,
GrapheneInputObjectType, GrapheneInterfaceType, GrapheneInputObjectType, GrapheneInterfaceType,
GrapheneObjectType, GrapheneScalarType, GrapheneObjectType, GrapheneScalarType,
@ -256,8 +257,14 @@ class TypeMap(GraphQLTypeMap):
field_type, field_type,
args=args, args=args,
resolver=field.get_resolver( resolver=field.get_resolver(
self.get_resolver_for_type(type, name, auto_resolver(
field.default_value)), self.get_resolver_for_type(
type,
name,
field.default_value
)
)
),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description) description=field.description)
field_name = field.name or self.get_name(name) field_name = field.name or self.get_name(name)
@ -287,7 +294,7 @@ class TypeMap(GraphQLTypeMap):
default_resolver = type._meta.default_resolver or get_default_resolver( default_resolver = type._meta.default_resolver or get_default_resolver(
) )
return partial(default_resolver, name, default_value) return final_resolver(partial(default_resolver, name, default_value))
def get_field_type(self, map, type): def get_field_type(self, map, type):
if isinstance(type, List): if isinstance(type, List):

View File

@ -1,12 +1,21 @@
from .resolver_from_annotations import resolver_from_annotations, is_wrapped_from_annotations from .resolver_from_annotations import resolver_from_annotations
def final_resolver(func):
func._is_final_resolver = True
return func
def auto_resolver(func=None): def auto_resolver(func=None):
annotations = getattr(func, '__annotations__', {}) if not func:
is_annotated = getattr(func, '_is_annotated', False) return
if (annotations or is_annotated) and not is_wrapped_from_annotations(func): if not is_final_resolver(func):
# Is a Graphene 2.0 resolver function # Is a Graphene 2.0 resolver function
return resolver_from_annotations(func) return final_resolver(resolver_from_annotations(func))
else: else:
return func return func
def is_final_resolver(func):
return getattr(func, '_is_final_resolver', False)

View File

@ -1,15 +1,10 @@
from ..pyutils.compat import signature from ..pyutils.compat import signature
from functools import wraps from functools import wraps, partial
def resolver_from_annotations(func): def resolver_from_annotations(func):
from ..types import Context, ResolveInfo from ..types import Context, ResolveInfo
_is_wrapped_from_annotations = is_wrapped_from_annotations(func)
assert not _is_wrapped_from_annotations, "The function {func_name} is already wrapped.".format(
func_name=func.func_name
)
func_signature = signature(func) func_signature = signature(func)
_context_var = None _context_var = None
@ -38,9 +33,7 @@ def resolver_from_annotations(func):
def inner(root, args, context, info): def inner(root, args, context, info):
return func(root, **args) return func(root, **args)
inner._is_wrapped_from_annotations = True if isinstance(func, partial):
return inner
return wraps(func)(inner) return wraps(func)(inner)
def is_wrapped_from_annotations(func):
return getattr(func, '_is_wrapped_from_annotations', False)

View File

@ -1,15 +1,15 @@
import pytest import pytest
from ..annotate import annotate from ..annotate import annotate
from ..auto_resolver import auto_resolver from ..auto_resolver import auto_resolver, final_resolver
from ...types import Context, ResolveInfo from ...types import Context, ResolveInfo
@final_resolver
def resolver(root, args, context, info): def resolver(root, args, context, info):
return root, args, context, info return root, args, context, info
@annotate
def resolver_annotated(root, **args): def resolver_annotated(root, **args):
return root, args, None, None return root, args, None, None