Fixed get_node

This commit is contained in:
Syrus Akbary 2016-05-11 23:47:10 -07:00
parent 379768930d
commit 7a27a9ef9f
4 changed files with 18 additions and 12 deletions

View File

@ -80,7 +80,7 @@ class NodeField(Field):
if not is_node(object_type) or (self.field_object_type and object_type != field_object_type):
return
return object_type.get_node(_id, info)
return object_type.get_node(_id, context, info)
@with_context
def resolver(self, instance, args, context, info):

View File

@ -16,7 +16,7 @@ class MyNode(relay.Node):
name = graphene.String()
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, id, info):
return MyNode(id=id, name='mo')

View File

@ -14,7 +14,7 @@ from ..core.types import Boolean, Field, List, String
from ..core.types.argument import ArgumentsGroup
from ..core.types.definitions import NonNull
from ..utils import memoize
from ..utils.wrap_resolver_function import wrap_resolver_function
from ..utils.wrap_resolver_function import has_context
from .fields import GlobalIDField
@ -108,7 +108,7 @@ class NodeMeta(InterfaceMeta):
get_node = getattr(cls, 'get_node', None)
assert get_node, 'get_node classmethod not found in %s Node' % cls
assert callable(get_node), 'get_node have to be callable'
args = 3
args = 4
if isinstance(get_node, staticmethod):
args -= 1
@ -120,12 +120,15 @@ class NodeMeta(InterfaceMeta):
@staticmethod
@wraps(get_node)
def wrapped_node(*node_args):
if len(node_args) < args:
node_args += (None, )
return get_node(*node_args[:-1])
setattr(cls, 'get_node', wrapped_node)
def wrapped_node(id, context=None, info=None):
node_args = [id, info, context]
if has_context(get_node):
return get_node(*node_args[:get_node_num_args-1], context=context)
if get_node_num_args-1 == 0:
return get_node(id)
return get_node(*node_args[:get_node_num_args-1])
node_func = wrapped_node
setattr(cls, 'get_node', node_func)
def construct(cls, *args, **kwargs):
cls = super(NodeMeta, cls).construct(*args, **kwargs)

View File

@ -6,11 +6,14 @@ def with_context(func):
return func
def has_context(func):
return getattr(func, 'with_context', None)
def wrap_resolver_function(func):
@wraps(func)
def inner(self, args, context, info):
with_context = getattr(func, 'with_context', None)
if with_context:
if has_context(func):
return func(self, args, context, info)
# For old compatibility
return func(self, args, info)