Updated Node

This commit is contained in:
Syrus Akbary 2016-06-14 21:45:44 -07:00
parent 6c7cf55b18
commit 8d4cf2d059
3 changed files with 22 additions and 9 deletions

View File

@ -3,25 +3,26 @@ import six
from graphql_relay import node_definitions, from_global_id, to_global_id from graphql_relay import node_definitions, from_global_id, to_global_id
from ..types.field import Field from ..types.field import Field
from ..types.options import Options
from ..types.objecttype import ObjectTypeMeta from ..types.objecttype import ObjectTypeMeta
from ..types.interface import Interface from ..types.interface import Interface
class NodeMeta(ObjectTypeMeta): class NodeMeta(ObjectTypeMeta):
def construct(cls, bases, attrs): def __new__(cls, name, bases, attrs):
cls = super(NodeMeta, cls).__new__(cls, name, bases, attrs)
is_object_type = cls.is_object_type() is_object_type = cls.is_object_type()
cls = super(NodeMeta, cls).construct(bases, attrs)
if not is_object_type: if not is_object_type:
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None) get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
id_resolver = getattr(cls, 'id_resolver', None)
assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__) assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__)
cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions( node_interface, node_field = node_definitions(
get_node_from_global_id, get_node_from_global_id,
id_resolver=cls.id_resolver, id_resolver=id_resolver,
) )
cls._meta.graphql_type = node_interface cls._meta = Options(None, graphql_type=node_interface)
cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None) cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, parent=cls, _creation_counter=None)
else: else:
# The interface provided by node_definitions is not an instance # The interface provided by node_definitions is not an instance
# of GrapheneInterfaceType, so it will have no graphql_type, # of GrapheneInterfaceType, so it will have no graphql_type,
@ -38,11 +39,15 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod @classmethod
def from_global_id(cls, global_id): def from_global_id(cls, global_id):
return from_global_id(global_id) if cls is Node:
return from_global_id(global_id)
raise NotImplementedError("You need to implement {}.from_global_id".format(cls.__name__))
@classmethod @classmethod
def to_global_id(cls, type, id): def to_global_id(cls, type, id):
return to_global_id(type, id) if cls is Node:
return to_global_id(type, id)
raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__))
@classmethod @classmethod
def id_resolver(cls, root, args, context, info): def id_resolver(cls, root, args, context, info):

View File

@ -6,6 +6,10 @@ from ...types.scalars import String, Int
class CustomNode(Node): class CustomNode(Node):
@staticmethod
def to_global_id(type, id):
return id
@staticmethod @staticmethod
def get_node_from_global_id(id, context, info): def get_node_from_global_id(id, context, info):
assert info.schema == schema assert info.schema == schema

View File

@ -35,6 +35,10 @@ def get_fields(in_type, attrs, bases, graphql_types=()):
extended_fields = list(get_fields_from_types(graphene_bases)) extended_fields = list(get_fields_from_types(graphene_bases))
local_fields = list(get_fields_from_attrs(in_type, attrs)) local_fields = list(get_fields_from_attrs(in_type, attrs))
# We asume the extended fields are already sorted, so we only
# have to sort the local fields, that are get from attrs
# and could be unordered as is a dict and not OrderedDict
local_fields = sorted(local_fields, key=lambda kv: kv[1])
field_names = set(f[0] for f in local_fields) field_names = set(f[0] for f in local_fields)
for name, extended_field in extended_fields: for name, extended_field in extended_fields:
@ -45,4 +49,4 @@ def get_fields(in_type, attrs, bases, graphql_types=()):
fields.extend(local_fields) fields.extend(local_fields)
return OrderedDict(sorted(fields, key=lambda kv: kv[1])) return OrderedDict(fields)