Fixed get_node

This commit is contained in:
Syrus Akbary 2016-05-11 23:47:10 -07:00
parent 379768930d
commit 7a27a9ef9f
4 changed files with 18 additions and 12 deletions

View File

@ -80,7 +80,7 @@ class NodeField(Field):
if not is_node(object_type) or (self.field_object_type and object_type != field_object_type): if not is_node(object_type) or (self.field_object_type and object_type != field_object_type):
return return
return object_type.get_node(_id, info) return object_type.get_node(_id, context, info)
@with_context @with_context
def resolver(self, instance, args, context, info): def resolver(self, instance, args, context, info):

View File

@ -16,7 +16,7 @@ class MyNode(relay.Node):
name = graphene.String() name = graphene.String()
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, id, info):
return MyNode(id=id, name='mo') return MyNode(id=id, name='mo')

View File

@ -14,7 +14,7 @@ from ..core.types import Boolean, Field, List, String
from ..core.types.argument import ArgumentsGroup from ..core.types.argument import ArgumentsGroup
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..utils import memoize from ..utils import memoize
from ..utils.wrap_resolver_function import wrap_resolver_function from ..utils.wrap_resolver_function import has_context
from .fields import GlobalIDField from .fields import GlobalIDField
@ -108,7 +108,7 @@ class NodeMeta(InterfaceMeta):
get_node = getattr(cls, 'get_node', None) get_node = getattr(cls, 'get_node', None)
assert get_node, 'get_node classmethod not found in %s Node' % cls assert get_node, 'get_node classmethod not found in %s Node' % cls
assert callable(get_node), 'get_node have to be callable' assert callable(get_node), 'get_node have to be callable'
args = 3 args = 4
if isinstance(get_node, staticmethod): if isinstance(get_node, staticmethod):
args -= 1 args -= 1
@ -120,12 +120,15 @@ class NodeMeta(InterfaceMeta):
@staticmethod @staticmethod
@wraps(get_node) @wraps(get_node)
def wrapped_node(*node_args): def wrapped_node(id, context=None, info=None):
if len(node_args) < args: node_args = [id, info, context]
node_args += (None, ) if has_context(get_node):
return get_node(*node_args[:-1]) return get_node(*node_args[:get_node_num_args-1], context=context)
if get_node_num_args-1 == 0:
setattr(cls, 'get_node', wrapped_node) return get_node(id)
return get_node(*node_args[:get_node_num_args-1])
node_func = wrapped_node
setattr(cls, 'get_node', node_func)
def construct(cls, *args, **kwargs): def construct(cls, *args, **kwargs):
cls = super(NodeMeta, cls).construct(*args, **kwargs) cls = super(NodeMeta, cls).construct(*args, **kwargs)

View File

@ -6,11 +6,14 @@ def with_context(func):
return func return func
def has_context(func):
return getattr(func, 'with_context', None)
def wrap_resolver_function(func): def wrap_resolver_function(func):
@wraps(func) @wraps(func)
def inner(self, args, context, info): def inner(self, args, context, info):
with_context = getattr(func, 'with_context', None) if has_context(func):
if with_context:
return func(self, args, context, info) return func(self, args, context, info)
# For old compatibility # For old compatibility
return func(self, args, info) return func(self, args, info)