Improved relay nodes and field copies

This commit is contained in:
Syrus Akbary 2016-06-10 00:23:31 -07:00
parent 522f769cad
commit 51e97510c0
9 changed files with 80 additions and 44 deletions

View File

@ -4,7 +4,7 @@ from graphene import implements, relay, resolve_only_args
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
class Ship(graphene.ObjectType, relay.Node):
class Ship(relay.Node, graphene.ObjectType):
'''A ship in the Star Wars saga'''
name = graphene.String(description='The name of the ship.')
@ -13,7 +13,7 @@ class Ship(graphene.ObjectType, relay.Node):
return get_ship(id)
class Faction(graphene.ObjectType, relay.Node):
class Faction(relay.Node, graphene.ObjectType):
'''A faction in the Star Wars saga'''
name = graphene.String(description='The name of the faction.')
# ships = relay.ConnectionField(

View File

@ -39,14 +39,9 @@ class ClientIDMutationMeta(MutationMeta):
input_fields=input_local_fields,
output_fields=cls._fields(bases, attrs, local_fields),
mutate_and_get_payload=cls.mutate_and_get_payload,
input_type_class=partial(GrapheneInputObjectType, graphene_type=cls),
input_field_class=InputField,
output_type_class=partial(GrapheneObjectType, graphene_type=cls),
field_class=Field,
)
cls._meta.graphql_type = field.type
cls.Field = partial(Field.copy_and_extend, field, type=None, _creation_counter=None)
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
return constructed

View File

@ -11,25 +11,30 @@ from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMet
class NodeMeta(ObjectTypeMeta):
def construct(cls, bases, attrs):
if not cls.is_object_type():
cls.get_node = attrs.pop('get_node')
is_object_type = cls.is_object_type()
cls = super(NodeMeta, cls).construct(bases, attrs)
if not is_object_type:
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__)
cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions(
cls.get_node,
get_node_from_global_id,
id_resolver=cls.id_resolver,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field,
)
cls._meta.graphql_type = node_interface
cls.Field = partial(Field.copy_and_extend, node_field, type=None, _creation_counter=None)
return super(NodeMeta, cls).construct(bases, attrs)
cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None)
else:
# The interface provided by node_definitions is not an instance
# of GrapheneInterfaceType, so it will have no graphql_type,
# so will not trigger Node.implements
cls.implements(cls)
return cls
class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod
def require_get_node(cls):
return cls == Node
return Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces
@classmethod
def from_global_id(cls, global_id):
@ -44,7 +49,7 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
@classmethod
def get_node(cls, global_id, context, info):
def get_node_from_global_id(cls, global_id, context, info):
try:
_type, _id = cls.from_global_id(global_id)
except:

View File

@ -7,7 +7,7 @@ from ...types import ObjectType, Schema, implements
from ...types.scalars import String
class MyNode(ObjectType, Node):
class MyNode(Node, ObjectType):
name = String()
@ -25,9 +25,7 @@ schema = Schema(query=RootQuery, types=[MyNode])
def test_node_no_get_node():
with pytest.raises(AssertionError) as excinfo:
class MyNode(ObjectType):
class Meta:
interfaces = [Node]
class MyNode(Node, ObjectType):
pass
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
@ -35,9 +33,8 @@ def test_node_no_get_node():
def test_node_no_get_node_with_meta():
with pytest.raises(AssertionError) as excinfo:
class MyNode(ObjectType):
class Meta:
interfaces = [Node]
class MyNode(Node, ObjectType):
pass
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)

View File

@ -7,7 +7,7 @@ from ...types.scalars import String, Int
class CustomNode(Node):
@staticmethod
def get_node(id, context, info):
def get_node_from_global_id(id, context, info):
assert info.schema == schema
if id in user_data:
return user_data.get(id)
@ -15,11 +15,11 @@ class CustomNode(Node):
return photo_data.get(id)
class User(ObjectType, CustomNode):
class User(CustomNode, ObjectType):
name = String()
class Photo(ObjectType, CustomNode):
class Photo(CustomNode, ObjectType):
width = Int()

View File

@ -116,20 +116,40 @@ class Field(AbstractField, GraphQLField, OrderedType):
@classmethod
def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args):
if isinstance(field, Field):
type = type or field._type
resolver = resolver or field._resolver
source = source or field.source
name = name or field._name
required = required or field.required
_creation_counter = field.creation_counter if _creation_counter is False else None
attname = field.attname
parent = field.parent
else:
# If is a GraphQLField
type = type or field.type
resolver = resolver or field.resolver
source = None
name = field.name
required = None
_creation_counter = None
attname = name
parent = None
new_field = cls(
type=type or field._type,
type=type,
args=to_arguments(args, field.args),
resolver=field._resolver,
source=source or getattr(field, 'source', None),
resolver=resolver,
source=source,
deprecation_reason=field.deprecation_reason,
name=field._name,
required=required or getattr(field, 'required', False),
name=name,
required=required,
description=field.description,
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
_creation_counter=_creation_counter,
**extra_args
)
new_field.attname = field.attname
new_field.parent = field.parent
new_field.attname = attname
new_field.parent = parent
return new_field
def __str__(self):
@ -168,14 +188,30 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType):
@classmethod
def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False):
if isinstance(field, Field):
type = type or field._type
name = name or field._name
required = required or field.required
_creation_counter = field.creation_counter if _creation_counter is False else None
attname = field.attname
parent = field.parent
else:
# If is a GraphQLField
type = type or field.type
name = field.name
required = None
_creation_counter = None
attname = None
parent = None
new_field = cls(
type=type or field._type,
name=name or field._name,
required=required or field.required,
type=type,
name=name,
required=required,
default_value=default_value or field.default_value,
description=description or field.description,
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
_creation_counter=_creation_counter,
)
new_field.attname = field.attname
new_field.parent = field.parent
new_field.attname = attname
new_field.parent = parent
return new_field

View File

@ -40,7 +40,10 @@ class Interface(six.with_metaclass(InterfaceTypeMeta)):
abstract = True
def __init__(self, *args, **kwargs):
raise Exception("An interface cannot be intitialized")
from .objecttype import ObjectType
if not isinstance(self, ObjectType):
raise Exception("An interface cannot be intitialized")
super(Interface, self).__init__(*args, **kwargs)
@classmethod
def implements(cls, object_type):

View File

@ -71,6 +71,7 @@ class ObjectTypeMeta(InterfaceTypeMeta):
cls._meta.interfaces,
cls.get_interfaces(bases),
)))
local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type:
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)

View File

@ -93,7 +93,6 @@ def test_objecttype_as_container_get_fields():
def test_parent_container_get_fields():
fields = Container._meta.graphql_type.get_fields()
print [(f.creation_counter, f.name) for f in fields.values()]
assert fields.keys() == ['field1', 'field2']