Updated resolver api

This commit is contained in:
Syrus Akbary 2017-07-27 02:51:25 -07:00
parent 586ea56693
commit 394a1beb07
33 changed files with 134 additions and 266 deletions

View File

@ -174,6 +174,17 @@ class Query(ObjectType):
user_connection = relay.ConnectionField(UserConnection) user_connection = relay.ConnectionField(UserConnection)
``` ```
## Mutation.mutate
Now only receive (`root`, `info`, `**args`)
## ClientIDMutation.mutate_and_get_payload
Now only receive (`root`, `info`, `**input`)
## New Features ## New Features
### InputObjectType ### InputObjectType

View File

@ -11,10 +11,9 @@ class Address(graphene.ObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
address = graphene.Field(Address, geo=GeoInput()) address = graphene.Field(Address, geo=GeoInput(required=True))
def resolve_address(self, args, context, info): def resolve_address(self, info, geo):
geo = args.get('geo')
return Address(latlng="({},{})".format(geo.get('lat'), geo.get('lng'))) return Address(latlng="({},{})".format(geo.get('lat'), geo.get('lng')))

View File

@ -9,9 +9,8 @@ class User(graphene.ObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
me = graphene.Field(User) me = graphene.Field(User)
@graphene.annotate(context=graphene.Context) def resolve_me(self, info):
def resolve_me(self, context): return info.context['user']
return context['user']
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)

View File

@ -11,7 +11,7 @@ class Query(graphene.ObjectType):
patron = graphene.Field(Patron) patron = graphene.Field(Patron)
def resolve_patron(self): def resolve_patron(self, info):
return Patron(id=1, name='Syrus', age=27) return Patron(id=1, name='Syrus', age=27)

View File

@ -15,7 +15,7 @@ class Character(graphene.Interface):
friends = graphene.List(lambda: Character) friends = graphene.List(lambda: Character)
appears_in = graphene.List(Episode) appears_in = graphene.List(Episode)
def resolve_friends(self): def resolve_friends(self, info):
# The character friends is a list of strings # The character friends is a list of strings
return [get_character(f) for f in self.friends] return [get_character(f) for f in self.friends]
@ -45,13 +45,13 @@ class Query(graphene.ObjectType):
id=graphene.String() id=graphene.String()
) )
def resolve_hero(self, episode=None): def resolve_hero(self, info, episode=None):
return get_hero(episode) return get_hero(episode)
def resolve_human(self, id): def resolve_human(self, info, id):
return get_human(id) return get_human(id)
def resolve_droid(self, id): def resolve_droid(self, info, id):
return get_droid(id) return get_droid(id)

View File

@ -1,5 +1,5 @@
import graphene import graphene
from graphene import annotate, relay, annotate from graphene import relay
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
@ -13,7 +13,7 @@ class Ship(graphene.ObjectType):
name = graphene.String(description='The name of the ship.') name = graphene.String(description='The name of the ship.')
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, id, info):
return get_ship(id) return get_ship(id)
@ -32,12 +32,12 @@ class Faction(graphene.ObjectType):
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.') ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.')
def resolve_ships(self, **args): def resolve_ships(self, info, **args):
# Transform the instance ship_ids into real instances # Transform the instance ship_ids into real instances
return [get_ship(ship_id) for ship_id in self.ships] return [get_ship(ship_id) for ship_id in self.ships]
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, id, info):
return get_faction(id) return get_faction(id)
@ -51,9 +51,7 @@ class IntroduceShip(relay.ClientIDMutation):
faction = graphene.Field(Faction) faction = graphene.Field(Faction)
@classmethod @classmethod
def mutate_and_get_payload(cls, input, context, info): def mutate_and_get_payload(cls, root, info, ship_name, faction_id, client_mutation_id=None):
ship_name = input.get('ship_name')
faction_id = input.get('faction_id')
ship = create_ship(ship_name, faction_id) ship = create_ship(ship_name, faction_id)
faction = get_faction(faction_id) faction = get_faction(faction_id)
return IntroduceShip(ship=ship, faction=faction) return IntroduceShip(ship=ship, faction=faction)
@ -64,10 +62,10 @@ class Query(graphene.ObjectType):
empire = graphene.Field(Faction) empire = graphene.Field(Faction)
node = relay.Node.Field() node = relay.Node.Field()
def resolve_rebels(self): def resolve_rebels(self, info):
return get_rebels() return get_rebels()
def resolve_empire(self): def resolve_empire(self, info):
return get_empire() return get_empire()

View File

@ -48,8 +48,6 @@ if not __SETUP__:
) )
from .utils.resolve_only_args import resolve_only_args from .utils.resolve_only_args import resolve_only_args
from .utils.module_loading import lazy_import from .utils.module_loading import lazy_import
from .utils.annotate import annotate
from .utils.auto_resolver import final_resolver, is_final_resolver
__all__ = [ __all__ = [
'ObjectType', 'ObjectType',
@ -82,11 +80,8 @@ if not __SETUP__:
'ConnectionField', 'ConnectionField',
'PageInfo', 'PageInfo',
'lazy_import', 'lazy_import',
'annotate',
'Context', 'Context',
'ResolveInfo', 'ResolveInfo',
'final_resolver',
'is_final_resolver',
# Deprecated # Deprecated
'AbstractType', 'AbstractType',

View File

@ -10,7 +10,6 @@ from ..types import (Boolean, Enum, Int, Interface, List, NonNull, Scalar,
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeOptions from ..types.objecttype import ObjectType, ObjectTypeOptions
from ..utils.deprecated import warn_deprecation from ..utils.deprecated import warn_deprecation
from ..utils.auto_resolver import final_resolver
from .node import is_node from .node import is_node
@ -132,8 +131,8 @@ class IterableConnectionField(Field):
return connection return connection
@classmethod @classmethod
def connection_resolver(cls, resolver, connection_type, root, args, context, info): def connection_resolver(cls, resolver, connection_type, root, info, **args):
resolved = resolver(root, args, context, info) resolved = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection_type, args) on_resolve = partial(cls.resolve_connection, connection_type, args)
if is_thenable(resolved): if is_thenable(resolved):
@ -143,7 +142,7 @@ class IterableConnectionField(Field):
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
return final_resolver(partial(self.connection_resolver, resolver, self.type)) return partial(self.connection_resolver, resolver, self.type)
ConnectionField = IterableConnectionField ConnectionField = IterableConnectionField

View File

@ -3,9 +3,8 @@ from collections import OrderedDict
from promise import Promise, is_thenable from promise import Promise, is_thenable
from ..types import Field, InputObjectType, String, Context, ResolveInfo from ..types import Field, InputObjectType, String
from ..types.mutation import Mutation from ..types.mutation import Mutation
from ..utils.annotate import annotate
class ClientIDMutation(Mutation): class ClientIDMutation(Mutation):
@ -58,18 +57,17 @@ class ClientIDMutation(Mutation):
) )
@classmethod @classmethod
@annotate(context=Context, info=ResolveInfo, _trigger_warning=False) def mutate(cls, root, info, input):
def mutate(cls, root, input, context, info):
def on_resolve(payload): def on_resolve(payload):
try: try:
payload.client_mutation_id = input.get('clientMutationId') payload.client_mutation_id = input.get('client_mutation_id')
except: except:
raise Exception( raise Exception(
('Cannot set client_mutation_id in the payload object {}' ('Cannot set client_mutation_id in the payload object {}'
).format(repr(payload))) ).format(repr(payload)))
return payload return payload
result = cls.mutate_and_get_payload(input, context, info) result = cls.mutate_and_get_payload(root, info, **input)
if is_thenable(result): if is_thenable(result):
return Promise.resolve(result).then(on_resolve) return Promise.resolve(result).then(on_resolve)

View File

@ -3,11 +3,9 @@ from functools import partial
from graphql_relay import from_global_id, to_global_id from graphql_relay import from_global_id, to_global_id
from ..types import ID, Field, Interface, ObjectType, Context, ResolveInfo from ..types import ID, Field, Interface, ObjectType
from ..types.interface import InterfaceOptions from ..types.interface import InterfaceOptions
from ..types.utils import get_type from ..types.utils import get_type
from ..utils.annotate import annotate
from ..utils.auto_resolver import final_resolver
def is_node(objecttype): def is_node(objecttype):
@ -31,15 +29,15 @@ class GlobalID(Field):
self.parent_type_name = parent_type._meta.name if parent_type else None self.parent_type_name = parent_type._meta.name if parent_type else None
@staticmethod @staticmethod
def id_resolver(parent_resolver, node, root, args, context, info, parent_type_name=None): def id_resolver(parent_resolver, node, root, info, parent_type_name=None, **args):
type_id = parent_resolver(root, args, context, info) type_id = parent_resolver(root, info, **args)
parent_type_name = parent_type_name or info.parent_type.name parent_type_name = parent_type_name or info.parent_type.name
return node.to_global_id(parent_type_name, type_id) # root._meta.name return node.to_global_id(parent_type_name, type_id) # root._meta.name
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return final_resolver(partial( return partial(
self.id_resolver, parent_resolver, self.node, parent_type_name=self.parent_type_name self.id_resolver, parent_resolver, self.node, parent_type_name=self.parent_type_name
)) )
class NodeField(Field): class NodeField(Field):
@ -83,12 +81,11 @@ class Node(AbstractNode):
return NodeField(cls, *args, **kwargs) return NodeField(cls, *args, **kwargs)
@classmethod @classmethod
@annotate(context=Context, info=ResolveInfo, _trigger_warning=False) def node_resolver(cls, root, info, id, only_type=None):
def node_resolver(cls, root, id, context, info, only_type=None): return cls.get_node_from_global_id(id, info, only_type)
return cls.get_node_from_global_id(id, context, info, only_type)
@classmethod @classmethod
def get_node_from_global_id(cls, global_id, context, info, only_type=None): def get_node_from_global_id(cls, global_id, info, only_type=None):
try: try:
_type, _id = cls.from_global_id(global_id) _type, _id = cls.from_global_id(global_id)
graphene_type = info.schema.get_type(_type).graphene_type graphene_type = info.schema.get_type(_type).graphene_type
@ -106,7 +103,7 @@ class Node(AbstractNode):
get_node = getattr(graphene_type, 'get_node', None) get_node = getattr(graphene_type, 'get_node', None)
if get_node: if get_node:
return get_node(_id, context, info) return get_node(_id, info)
@classmethod @classmethod
def from_global_id(cls, global_id): def from_global_id(cls, global_id):

View File

@ -31,13 +31,13 @@ class Query(ObjectType):
node = Node.Field() node = Node.Field()
def resolve_letters(self, **args): def resolve_letters(self, info, **args):
return list(letters.values()) return list(letters.values())
def resolve_promise_letters(self, **args): def resolve_promise_letters(self, info, **args):
return Promise.resolve(list(letters.values())) return Promise.resolve(list(letters.values()))
def resolve_connection_letters(self, **args): def resolve_connection_letters(self, info, **args):
return LetterConnection( return LetterConnection(
page_info=PageInfo( page_info=PageInfo(
has_next_page=True, has_next_page=True,

View File

@ -48,7 +48,7 @@ def test_global_id_defaults_to_info_parent_type():
my_id = '1' my_id = '1'
gid = GlobalID() gid = GlobalID()
id_resolver = gid.get_resolver(lambda *_: my_id) id_resolver = gid.get_resolver(lambda *_: my_id)
my_global_id = id_resolver(None, None, None, Info(User)) my_global_id = id_resolver(None, Info(User))
assert my_global_id == to_global_id(User._meta.name, my_id) assert my_global_id == to_global_id(User._meta.name, my_id)
@ -56,5 +56,5 @@ def test_global_id_allows_setting_customer_parent_type():
my_id = '1' my_id = '1'
gid = GlobalID(parent_type=User) gid = GlobalID(parent_type=User)
id_resolver = gid.get_resolver(lambda *_: my_id) id_resolver = gid.get_resolver(lambda *_: my_id)
my_global_id = id_resolver(None, None, None, None) my_global_id = id_resolver(None, None)
assert my_global_id == to_global_id(User._meta.name, my_id) assert my_global_id == to_global_id(User._meta.name, my_id)

View File

@ -27,8 +27,7 @@ class SaySomething(ClientIDMutation):
phrase = String() phrase = String()
@staticmethod @staticmethod
def mutate_and_get_payload(args, context, info): def mutate_and_get_payload(self, info, what, client_mutation_id=None):
what = args.get('what')
return SaySomething(phrase=str(what)) return SaySomething(phrase=str(what))
@ -40,8 +39,7 @@ class SaySomethingPromise(ClientIDMutation):
phrase = String() phrase = String()
@staticmethod @staticmethod
def mutate_and_get_payload(args, context, info): def mutate_and_get_payload(self, info, what, client_mutation_id=None):
what = args.get('what')
return Promise.resolve(SaySomething(phrase=str(what))) return Promise.resolve(SaySomething(phrase=str(what)))
@ -59,13 +57,11 @@ class OtherMutation(ClientIDMutation):
name = String() name = String()
my_node_edge = Field(MyEdge) my_node_edge = Field(MyEdge)
@classmethod @staticmethod
def mutate_and_get_payload(cls, args, context, info): def mutate_and_get_payload(self, info, shared='', additional_field='', client_mutation_id=None):
shared = args.get('shared', '')
additionalField = args.get('additionalField', '')
edge_type = MyEdge edge_type = MyEdge
return OtherMutation( return OtherMutation(
name=shared + additionalField, name=shared + additional_field,
my_node_edge=edge_type(cursor='1', node=MyNode(name='name'))) my_node_edge=edge_type(cursor='1', node=MyNode(name='name')))

View File

@ -15,7 +15,7 @@ class CustomNode(Node):
return id return id
@staticmethod @staticmethod
def get_node_from_global_id(id, context, info, only_type=None): def get_node_from_global_id(id, info, only_type=None):
assert info.schema == schema assert info.schema == schema
if id in user_data: if id in user_data:
return user_data.get(id) return user_data.get(id)

View File

@ -29,7 +29,7 @@ class CreatePost(graphene.Mutation):
result = graphene.Field(CreatePostResult) result = graphene.Field(CreatePostResult)
def mutate(self, text): def mutate(self, info, text):
result = Success(yeah='yeah') result = Success(yeah='yeah')
return CreatePost(result=result) return CreatePost(result=result)

View File

@ -6,7 +6,7 @@ import graphene
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
some_field = graphene.String(from_=graphene.String(name="from")) some_field = graphene.String(from_=graphene.String(name="from"))
def resolve_some_field(self, from_=None): def resolve_some_field(self, info, from_=None):
return from_ return from_

View File

@ -11,7 +11,20 @@ class InputObjectTypeOptions(BaseOptions):
create_container = None # type: Callable create_container = None # type: Callable
class InputObjectType(dict, UnmountedType, BaseType): class InputObjectTypeContainer(dict, BaseType):
class Meta:
abstract = True
def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
def __init_subclass__(cls, *args, **kwargs):
pass
class InputObjectType(UnmountedType, BaseType):
''' '''
Input Object Type Definition Input Object Type Definition
@ -20,30 +33,9 @@ class InputObjectType(dict, UnmountedType, BaseType):
Using `NonNull` will ensure that a value must be provided by the query Using `NonNull` will ensure that a value must be provided by the query
''' '''
def __init__(self, *args, **kwargs):
as_container = kwargs.pop('_as_container', False)
if as_container:
# Is inited as container for the input args
self.__init_container__(*args, **kwargs)
else:
# Is inited as UnmountedType, e.g.
#
# class MyObjectType(graphene.ObjectType):
# my_input = MyInputType(required=True)
#
UnmountedType.__init__(self, *args, **kwargs)
def __init_container__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
@classmethod @classmethod
def create_container(cls, data): def __init_subclass_with_meta__(cls, container=None, **options):
return cls(data, _as_container=True)
@classmethod
def __init_subclass_with_meta__(cls, create_container=None, **options):
_meta = InputObjectTypeOptions(cls) _meta = InputObjectTypeOptions(cls)
fields = OrderedDict() fields = OrderedDict()
@ -53,9 +45,9 @@ class InputObjectType(dict, UnmountedType, BaseType):
) )
_meta.fields = fields _meta.fields = fields
if create_container is None: if container is None:
create_container = cls.create_container container = type(cls.__name__, (InputObjectTypeContainer, cls), {})
_meta.create_container = create_container _meta.container = container
super(InputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(InputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@classmethod @classmethod

View File

@ -37,14 +37,10 @@ class Interface(BaseType):
super(Interface, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(Interface, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@classmethod @classmethod
def resolve_type(cls, instance, context, info): def resolve_type(cls, instance, info):
from .objecttype import ObjectType from .objecttype import ObjectType
if isinstance(instance, ObjectType): if isinstance(instance, ObjectType):
return type(instance) return type(instance)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized") raise Exception("An Interface cannot be intitialized")
@classmethod
def implements(cls, objecttype):
pass

View File

@ -1,8 +1,8 @@
def attr_resolver(attname, default_value, root, args, context, info): def attr_resolver(attname, default_value, root, info, **args):
return getattr(root, attname, default_value) return getattr(root, attname, default_value)
def dict_resolver(attname, default_value, root, args, context, info): def dict_resolver(attname, default_value, root, info, **args):
return root.get(attname, default_value) return root.get(attname, default_value)

View File

@ -11,10 +11,10 @@ class Query(ObjectType):
datetime = DateTime(_in=DateTime(name='in')) datetime = DateTime(_in=DateTime(name='in'))
time = Time(_at=Time(name='at')) time = Time(_at=Time(name='at'))
def resolve_datetime(self, _in=None): def resolve_datetime(self, info, _in=None):
return _in return _in
def resolve_time(self, _at=None): def resolve_time(self, info, _at=None):
return _at return _at

View File

@ -6,7 +6,7 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
generic = GenericScalar(input=GenericScalar()) generic = GenericScalar(input=GenericScalar())
def resolve_generic(self, input=None): def resolve_generic(self, info, input=None):
return input return input

View File

@ -7,7 +7,7 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
json = JSONString(input=JSONString()) json = JSONString(input=JSONString())
def resolve_json(self, input): def resolve_json(self, info, input):
return input return input
schema = Schema(query=Query) schema = Schema(query=Query)

View File

@ -12,13 +12,13 @@ def test_generate_mutation_no_args():
class MyMutation(Mutation): class MyMutation(Mutation):
'''Documentation''' '''Documentation'''
def mutate(self, **args): def mutate(self, info, **args):
return args return args
assert issubclass(MyMutation, ObjectType) assert issubclass(MyMutation, ObjectType)
assert MyMutation._meta.name == "MyMutation" assert MyMutation._meta.name == "MyMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
resolved = MyMutation.Field().resolver(None, name='Peter') resolved = MyMutation.Field().resolver(None, None, name='Peter')
assert resolved == {'name': 'Peter'} assert resolved == {'name': 'Peter'}
@ -29,12 +29,12 @@ def test_generate_mutation_with_meta():
name = 'MyOtherMutation' name = 'MyOtherMutation'
description = 'Documentation' description = 'Documentation'
def mutate(self, **args): def mutate(self, info, **args):
return args return args
assert MyMutation._meta.name == "MyOtherMutation" assert MyMutation._meta.name == "MyOtherMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
resolved = MyMutation.Field().resolver(None, name='Peter') resolved = MyMutation.Field().resolver(None, None, name='Peter')
assert resolved == {'name': 'Peter'} assert resolved == {'name': 'Peter'}
@ -59,13 +59,13 @@ def test_mutation_custom_output_type():
Output = User Output = User
def mutate(self, name): def mutate(self, info, name):
return User(name=name) return User(name=name)
field = CreateUser.Field() field = CreateUser.Field()
assert field.type == User assert field.type == User
assert field.args == {'name': Argument(String)} assert field.args == {'name': Argument(String)}
resolved = field.resolver(None, name='Peter') resolved = field.resolver(None, None, name='Peter')
assert isinstance(resolved, User) assert isinstance(resolved, User)
assert resolved.name == 'Peter' assert resolved.name == 'Peter'
@ -81,7 +81,7 @@ def test_mutation_execution():
name = String() name = String()
dynamic = Dynamic(lambda: String()) dynamic = Dynamic(lambda: String())
def mutate(self, name, dynamic): def mutate(self, info, name, dynamic):
return CreateUser(name=name, dynamic=dynamic) return CreateUser(name=name, dynamic=dynamic)
class Query(ObjectType): class Query(ObjectType):

View File

@ -14,7 +14,6 @@ from ..schema import Schema
from ..structures import List from ..structures import List
from ..union import Union from ..union import Union
from ..context import Context from ..context import Context
from ...utils.annotate import annotate
def test_query(): def test_query():
@ -39,14 +38,14 @@ def test_query_union():
one = String() one = String()
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
return isinstance(root, one_object) return isinstance(root, one_object)
class Two(ObjectType): class Two(ObjectType):
two = String() two = String()
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
return isinstance(root, two_object) return isinstance(root, two_object)
class MyUnion(Union): class MyUnion(Union):
@ -57,7 +56,7 @@ def test_query_union():
class Query(ObjectType): class Query(ObjectType):
unions = List(MyUnion) unions = List(MyUnion)
def resolve_unions(self): def resolve_unions(self, info):
return [one_object(), two_object()] return [one_object(), two_object()]
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -91,7 +90,7 @@ def test_query_interface():
one = String() one = String()
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
return isinstance(root, one_object) return isinstance(root, one_object)
class Two(ObjectType): class Two(ObjectType):
@ -102,13 +101,13 @@ def test_query_interface():
two = String() two = String()
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
return isinstance(root, two_object) return isinstance(root, two_object)
class Query(ObjectType): class Query(ObjectType):
interfaces = List(MyInterface) interfaces = List(MyInterface)
def resolve_interfaces(self): def resolve_interfaces(self, info):
return [one_object(), two_object()] return [one_object(), two_object()]
hello_schema = Schema(Query, types=[One, Two]) hello_schema = Schema(Query, types=[One, Two])
@ -156,7 +155,7 @@ def test_query_wrong_default_value():
field = String() field = String()
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
return isinstance(root, MyType) return isinstance(root, MyType)
class Query(ObjectType): class Query(ObjectType):
@ -188,7 +187,7 @@ def test_query_resolve_function():
class Query(ObjectType): class Query(ObjectType):
hello = String() hello = String()
def resolve_hello(self): def resolve_hello(self, info):
return 'World' return 'World'
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -202,7 +201,7 @@ def test_query_arguments():
class Query(ObjectType): class Query(ObjectType):
test = String(a_str=String(), a_int=Int()) test = String(a_str=String(), a_int=Int())
def resolve_test(self, **args): def resolve_test(self, info, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(',', ':'))
test_schema = Schema(Query) test_schema = Schema(Query)
@ -231,7 +230,7 @@ def test_query_input_field():
class Query(ObjectType): class Query(ObjectType):
test = String(a_input=Input()) test = String(a_input=Input())
def resolve_test(self, **args): def resolve_test(self, info, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(',', ':'))
test_schema = Schema(Query) test_schema = Schema(Query)
@ -254,10 +253,10 @@ def test_query_middlewares():
hello = String() hello = String()
other = String() other = String()
def resolve_hello(self): def resolve_hello(self, info):
return 'World' return 'World'
def resolve_other(self): def resolve_other(self, info):
return 'other' return 'other'
def reversed_middleware(next, *args, **kwargs): def reversed_middleware(next, *args, **kwargs):
@ -280,14 +279,14 @@ def test_objecttype_on_instances():
class ShipType(ObjectType): class ShipType(ObjectType):
name = String(description="Ship name", required=True) name = String(description="Ship name", required=True)
def resolve_name(self): def resolve_name(self, info):
# Here self will be the Ship instance returned in resolve_ship # Here self will be the Ship instance returned in resolve_ship
return self.name return self.name
class Query(ObjectType): class Query(ObjectType):
ship = Field(ShipType) ship = Field(ShipType)
def resolve_ship(self): def resolve_ship(self, info):
return Ship(name='xwing') return Ship(name='xwing')
schema = Schema(query=Query) schema = Schema(query=Query)
@ -302,7 +301,7 @@ def test_big_list_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_ints = List(Int) all_ints = List(Int)
def resolve_all_ints(self): def resolve_all_ints(self, info):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -319,7 +318,7 @@ def test_big_list_query_compiled_query_benchmark(benchmark):
class Query(ObjectType): class Query(ObjectType):
all_ints = List(Int) all_ints = List(Int)
def resolve_all_ints(self): def resolve_all_ints(self, info):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -420,15 +419,13 @@ def test_query_annotated_resolvers():
context = String() context = String()
info = String() info = String()
def resolve_annotated(self, id): def resolve_annotated(self, info, id):
return "{}-{}".format(self, id) return "{}-{}".format(self, id)
@annotate(context=Context, _trigger_warning=False) def resolve_context(self, info):
def resolve_context(self, context): assert isinstance(info.context, Context)
assert isinstance(context, Context) return "{}-{}".format(self, info.context.key)
return "{}-{}".format(self, context.key)
@annotate(info=ResolveInfo, _trigger_warning=False)
def resolve_info(self, info): def resolve_info(self, info):
assert isinstance(info, ResolveInfo) assert isinstance(info, ResolveInfo)
return "{}-{}".format(self, info.field_name) return "{}-{}".format(self, info.field_name)

View File

@ -16,22 +16,22 @@ class demo_obj(object):
def test_attr_resolver(): def test_attr_resolver():
resolved = attr_resolver('attr', None, demo_obj, args, context, info) resolved = attr_resolver('attr', None, demo_obj, info, **args)
assert resolved == 'value' assert resolved == 'value'
def test_attr_resolver_default_value(): def test_attr_resolver_default_value():
resolved = attr_resolver('attr2', 'default', demo_obj, args, context, info) resolved = attr_resolver('attr2', 'default', demo_obj, info, **args)
assert resolved == 'default' assert resolved == 'default'
def test_dict_resolver(): def test_dict_resolver():
resolved = dict_resolver('attr', None, demo_dict, args, context, info) resolved = dict_resolver('attr', None, demo_dict, info, **args)
assert resolved == 'value' assert resolved == 'value'
def test_dict_resolver_default_value(): def test_dict_resolver_default_value():
resolved = dict_resolver('attr2', 'default', demo_dict, args, context, info) resolved = dict_resolver('attr2', 'default', demo_dict, info, **args)
assert resolved == 'default' assert resolved == 'default'

View File

@ -204,5 +204,5 @@ def test_objecttype_with_possible_types():
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
graphql_type = typemap['MyObjectType'] graphql_type = typemap['MyObjectType']
assert graphql_type.is_type_of assert graphql_type.is_type_of
assert graphql_type.is_type_of({}, None, None) is True assert graphql_type.is_type_of({}, None) is True
assert graphql_type.is_type_of(MyObjectType(), None, None) is False assert graphql_type.is_type_of(MyObjectType(), None) is False

View File

@ -6,7 +6,7 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
uuid = UUID(input=UUID()) uuid = UUID(input=UUID())
def resolve_uuid(self, input): def resolve_uuid(self, info, input):
return input return input
schema = Schema(query=Query) schema = Schema(query=Query)

View File

@ -11,7 +11,6 @@ from graphql.type.typemap import GraphQLTypeMap
from ..utils.get_unbound_function import get_unbound_function from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
from ..utils.auto_resolver import auto_resolver, final_resolver
from .definitions import (GrapheneEnumType, GrapheneGraphQLType, from .definitions import (GrapheneEnumType, GrapheneGraphQLType,
GrapheneInputObjectType, GrapheneInterfaceType, GrapheneInputObjectType, GrapheneInterfaceType,
GrapheneObjectType, GrapheneScalarType, GrapheneObjectType, GrapheneScalarType,
@ -38,12 +37,12 @@ def is_graphene_type(_type):
return True return True
def resolve_type(resolve_type_func, map, type_name, root, context, info): def resolve_type(resolve_type_func, map, type_name, root, info):
_type = resolve_type_func(root, context, info) _type = resolve_type_func(root, info)
if not _type: if not _type:
return_type = map[type_name] return_type = map[type_name]
return get_default_resolve_type_fn(root, context, info, return_type) return get_default_resolve_type_fn(root, info, return_type)
if inspect.isclass(_type) and issubclass(_type, ObjectType): if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name) graphql_type = map.get(_type._meta.name)
@ -55,7 +54,7 @@ def resolve_type(resolve_type_func, map, type_name, root, context, info):
return _type return _type
def is_type_of_from_possible_types(possible_types, root, context, info): def is_type_of_from_possible_types(possible_types, root, info):
return isinstance(root, possible_types) return isinstance(root, possible_types)
@ -196,7 +195,7 @@ class TypeMap(GraphQLTypeMap):
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
container_type=type._meta.create_container, container_type=type._meta.container,
fields=partial( fields=partial(
self.construct_fields_for_type, map, type, is_input_type=True), self.construct_fields_for_type, map, type, is_input_type=True),
) )
@ -240,7 +239,7 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLInputObjectField( _field = GraphQLInputObjectField(
field_type, field_type,
default_value=field.default_value, default_value=field.default_value,
out_name=field.name or name, out_name=name,
description=field.description) description=field.description)
else: else:
args = OrderedDict() args = OrderedDict()
@ -256,13 +255,13 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=auto_resolver(field.get_resolver( resolver=field.get_resolver(
auto_resolver(self.get_resolver_for_type( self.get_resolver_for_type(
type, type,
name, name,
field.default_value field.default_value
)) )
)), ),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description) description=field.description)
field_name = field.name or self.get_name(name) field_name = field.name or self.get_name(name)
@ -292,7 +291,7 @@ class TypeMap(GraphQLTypeMap):
default_resolver = type._meta.default_resolver or get_default_resolver( default_resolver = type._meta.default_resolver or get_default_resolver(
) )
return final_resolver(partial(default_resolver, name, default_value)) return partial(default_resolver, name, default_value)
def get_field_type(self, map, type): def get_field_type(self, map, type):
if isinstance(type, List): if isinstance(type, List):

View File

@ -34,7 +34,7 @@ class Union(UnmountedType, BaseType):
return cls return cls
@classmethod @classmethod
def resolve_type(cls, instance, context, info): def resolve_type(cls, instance, info):
from .objecttype import ObjectType from .objecttype import ObjectType
if isinstance(instance, ObjectType): if isinstance(instance, ObjectType):
return type(instance) return type(instance)

View File

@ -1,21 +0,0 @@
from .resolver_from_annotations import resolver_from_annotations
def final_resolver(func):
func._is_final_resolver = True
return func
def auto_resolver(func=None):
if not func:
return
if not is_final_resolver(func):
# Is a Graphene 2.0 resolver function
return final_resolver(resolver_from_annotations(func))
else:
return func
def is_final_resolver(func):
return getattr(func, '_is_final_resolver', False)

View File

@ -1,18 +1,11 @@
from six import PY2 from functools import wraps
from .annotate import annotate
from .deprecated import deprecated from .deprecated import deprecated
if PY2:
deprecation_reason = (
'Please use @annotate instead.'
)
else:
deprecation_reason = (
'Please use Python 3 type annotations instead. Read more: '
'https://docs.python.org/3/library/typing.html'
)
@deprecated('This function is deprecated')
@deprecated(deprecation_reason)
def resolve_only_args(func): def resolve_only_args(func):
return annotate(func) @wraps(func)
def wrapped_func(root, info, **args):
return func(root, **args)
return wrapped_func

View File

@ -1,36 +0,0 @@
import pytest
from ..annotate import annotate
from ..auto_resolver import auto_resolver, final_resolver
from ...types import Context, ResolveInfo
@final_resolver
def resolver(root, args, context, info):
return root, args, context, info
def resolver_annotated(root, **args):
return root, args, None, None
@annotate(context=Context, info=ResolveInfo, _trigger_warning=False)
def resolver_with_context_and_info(root, context, info, **args):
return root, args, context, info
def test_auto_resolver_non_annotated():
decorated_resolver = auto_resolver(resolver)
# We make sure the function is not wrapped
assert decorated_resolver == resolver
assert decorated_resolver(1, {}, 2, 3) == (1, {}, 2, 3)
def test_auto_resolver_annotated():
decorated_resolver = auto_resolver(resolver_annotated)
assert decorated_resolver(1, {}, 2, 3) == (1, {}, None, None)
def test_auto_resolver_annotated_with_context_and_info():
decorated_resolver = auto_resolver(resolver_with_context_and_info)
assert decorated_resolver(1, {}, 2, 3) == (1, {}, 2, 3)

View File

@ -1,44 +0,0 @@
import pytest
from ..annotate import annotate
from ..resolver_from_annotations import resolver_from_annotations
from ...types import Context, ResolveInfo
@annotate
def func(root, **args):
return root, args, None, None
@annotate(context=Context)
def func_with_context(root, context, **args):
return root, args, context, None
@annotate(info=ResolveInfo)
def func_with_info(root, info, **args):
return root, args, None, info
@annotate(context=Context, info=ResolveInfo)
def func_with_context_and_info(root, context, info, **args):
return root, args, context, info
root = 1
args = {
'arg': 0
}
context = 2
info = 3
@pytest.mark.parametrize("func,expected", [
(func, (1, {'arg': 0}, None, None)),
(func_with_context, (1, {'arg': 0}, 2, None)),
(func_with_info, (1, {'arg': 0}, None, 3)),
(func_with_context_and_info, (1, {'arg': 0}, 2, 3)),
])
def test_resolver_from_annotations(func, expected):
resolver_func = resolver_from_annotations(func)
resolved = resolver_func(root, args, context, info)
assert resolved == expected