mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-26 03:23:55 +03:00
Improved relay nodes and field copies
This commit is contained in:
parent
522f769cad
commit
51e97510c0
|
@ -4,7 +4,7 @@ from graphene import implements, relay, resolve_only_args
|
|||
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
|
||||
|
||||
|
||||
class Ship(graphene.ObjectType, relay.Node):
|
||||
class Ship(relay.Node, graphene.ObjectType):
|
||||
'''A ship in the Star Wars saga'''
|
||||
name = graphene.String(description='The name of the ship.')
|
||||
|
||||
|
@ -13,7 +13,7 @@ class Ship(graphene.ObjectType, relay.Node):
|
|||
return get_ship(id)
|
||||
|
||||
|
||||
class Faction(graphene.ObjectType, relay.Node):
|
||||
class Faction(relay.Node, graphene.ObjectType):
|
||||
'''A faction in the Star Wars saga'''
|
||||
name = graphene.String(description='The name of the faction.')
|
||||
# ships = relay.ConnectionField(
|
||||
|
|
|
@ -39,14 +39,9 @@ class ClientIDMutationMeta(MutationMeta):
|
|||
input_fields=input_local_fields,
|
||||
output_fields=cls._fields(bases, attrs, local_fields),
|
||||
mutate_and_get_payload=cls.mutate_and_get_payload,
|
||||
|
||||
input_type_class=partial(GrapheneInputObjectType, graphene_type=cls),
|
||||
input_field_class=InputField,
|
||||
output_type_class=partial(GrapheneObjectType, graphene_type=cls),
|
||||
field_class=Field,
|
||||
)
|
||||
cls._meta.graphql_type = field.type
|
||||
cls.Field = partial(Field.copy_and_extend, field, type=None, _creation_counter=None)
|
||||
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
|
||||
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
|
||||
return constructed
|
||||
|
||||
|
|
|
@ -11,25 +11,30 @@ from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMet
|
|||
class NodeMeta(ObjectTypeMeta):
|
||||
|
||||
def construct(cls, bases, attrs):
|
||||
if not cls.is_object_type():
|
||||
cls.get_node = attrs.pop('get_node')
|
||||
is_object_type = cls.is_object_type()
|
||||
cls = super(NodeMeta, cls).construct(bases, attrs)
|
||||
if not is_object_type:
|
||||
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
|
||||
assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__)
|
||||
cls.id_resolver = attrs.pop('id_resolver', None)
|
||||
node_interface, node_field = node_definitions(
|
||||
cls.get_node,
|
||||
get_node_from_global_id,
|
||||
id_resolver=cls.id_resolver,
|
||||
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
||||
field_class=Field,
|
||||
)
|
||||
cls._meta.graphql_type = node_interface
|
||||
cls.Field = partial(Field.copy_and_extend, node_field, type=None, _creation_counter=None)
|
||||
return super(NodeMeta, cls).construct(bases, attrs)
|
||||
|
||||
cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None)
|
||||
else:
|
||||
# The interface provided by node_definitions is not an instance
|
||||
# of GrapheneInterfaceType, so it will have no graphql_type,
|
||||
# so will not trigger Node.implements
|
||||
cls.implements(cls)
|
||||
return cls
|
||||
|
||||
class Node(six.with_metaclass(NodeMeta, Interface)):
|
||||
|
||||
@classmethod
|
||||
def require_get_node(cls):
|
||||
return cls == Node
|
||||
return Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces
|
||||
|
||||
@classmethod
|
||||
def from_global_id(cls, global_id):
|
||||
|
@ -44,7 +49,7 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
|
|||
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, global_id, context, info):
|
||||
def get_node_from_global_id(cls, global_id, context, info):
|
||||
try:
|
||||
_type, _id = cls.from_global_id(global_id)
|
||||
except:
|
||||
|
|
|
@ -7,7 +7,7 @@ from ...types import ObjectType, Schema, implements
|
|||
from ...types.scalars import String
|
||||
|
||||
|
||||
class MyNode(ObjectType, Node):
|
||||
class MyNode(Node, ObjectType):
|
||||
|
||||
name = String()
|
||||
|
||||
|
@ -25,9 +25,7 @@ schema = Schema(query=RootQuery, types=[MyNode])
|
|||
|
||||
def test_node_no_get_node():
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
class MyNode(ObjectType):
|
||||
class Meta:
|
||||
interfaces = [Node]
|
||||
class MyNode(Node, ObjectType):
|
||||
pass
|
||||
|
||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||
|
@ -35,9 +33,8 @@ def test_node_no_get_node():
|
|||
|
||||
def test_node_no_get_node_with_meta():
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
class MyNode(ObjectType):
|
||||
class Meta:
|
||||
interfaces = [Node]
|
||||
class MyNode(Node, ObjectType):
|
||||
pass
|
||||
|
||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from ...types.scalars import String, Int
|
|||
|
||||
class CustomNode(Node):
|
||||
@staticmethod
|
||||
def get_node(id, context, info):
|
||||
def get_node_from_global_id(id, context, info):
|
||||
assert info.schema == schema
|
||||
if id in user_data:
|
||||
return user_data.get(id)
|
||||
|
@ -15,11 +15,11 @@ class CustomNode(Node):
|
|||
return photo_data.get(id)
|
||||
|
||||
|
||||
class User(ObjectType, CustomNode):
|
||||
class User(CustomNode, ObjectType):
|
||||
name = String()
|
||||
|
||||
|
||||
class Photo(ObjectType, CustomNode):
|
||||
class Photo(CustomNode, ObjectType):
|
||||
width = Int()
|
||||
|
||||
|
||||
|
|
|
@ -116,20 +116,40 @@ class Field(AbstractField, GraphQLField, OrderedType):
|
|||
|
||||
@classmethod
|
||||
def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args):
|
||||
if isinstance(field, Field):
|
||||
type = type or field._type
|
||||
resolver = resolver or field._resolver
|
||||
source = source or field.source
|
||||
name = name or field._name
|
||||
required = required or field.required
|
||||
_creation_counter = field.creation_counter if _creation_counter is False else None
|
||||
attname = field.attname
|
||||
parent = field.parent
|
||||
else:
|
||||
# If is a GraphQLField
|
||||
type = type or field.type
|
||||
resolver = resolver or field.resolver
|
||||
source = None
|
||||
name = field.name
|
||||
required = None
|
||||
_creation_counter = None
|
||||
attname = name
|
||||
parent = None
|
||||
|
||||
new_field = cls(
|
||||
type=type or field._type,
|
||||
type=type,
|
||||
args=to_arguments(args, field.args),
|
||||
resolver=field._resolver,
|
||||
source=source or getattr(field, 'source', None),
|
||||
resolver=resolver,
|
||||
source=source,
|
||||
deprecation_reason=field.deprecation_reason,
|
||||
name=field._name,
|
||||
required=required or getattr(field, 'required', False),
|
||||
name=name,
|
||||
required=required,
|
||||
description=field.description,
|
||||
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
|
||||
_creation_counter=_creation_counter,
|
||||
**extra_args
|
||||
)
|
||||
new_field.attname = field.attname
|
||||
new_field.parent = field.parent
|
||||
new_field.attname = attname
|
||||
new_field.parent = parent
|
||||
return new_field
|
||||
|
||||
def __str__(self):
|
||||
|
@ -168,14 +188,30 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType):
|
|||
|
||||
@classmethod
|
||||
def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False):
|
||||
if isinstance(field, Field):
|
||||
type = type or field._type
|
||||
name = name or field._name
|
||||
required = required or field.required
|
||||
_creation_counter = field.creation_counter if _creation_counter is False else None
|
||||
attname = field.attname
|
||||
parent = field.parent
|
||||
else:
|
||||
# If is a GraphQLField
|
||||
type = type or field.type
|
||||
name = field.name
|
||||
required = None
|
||||
_creation_counter = None
|
||||
attname = None
|
||||
parent = None
|
||||
|
||||
new_field = cls(
|
||||
type=type or field._type,
|
||||
name=name or field._name,
|
||||
required=required or field.required,
|
||||
type=type,
|
||||
name=name,
|
||||
required=required,
|
||||
default_value=default_value or field.default_value,
|
||||
description=description or field.description,
|
||||
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
|
||||
_creation_counter=_creation_counter,
|
||||
)
|
||||
new_field.attname = field.attname
|
||||
new_field.parent = field.parent
|
||||
new_field.attname = attname
|
||||
new_field.parent = parent
|
||||
return new_field
|
||||
|
|
|
@ -40,7 +40,10 @@ class Interface(six.with_metaclass(InterfaceTypeMeta)):
|
|||
abstract = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("An interface cannot be intitialized")
|
||||
from .objecttype import ObjectType
|
||||
if not isinstance(self, ObjectType):
|
||||
raise Exception("An interface cannot be intitialized")
|
||||
super(Interface, self).__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def implements(cls, object_type):
|
||||
|
|
|
@ -71,6 +71,7 @@ class ObjectTypeMeta(InterfaceTypeMeta):
|
|||
cls._meta.interfaces,
|
||||
cls.get_interfaces(bases),
|
||||
)))
|
||||
|
||||
local_fields = cls._extract_local_fields(attrs)
|
||||
if not cls._meta.graphql_type:
|
||||
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)
|
||||
|
|
|
@ -93,7 +93,6 @@ def test_objecttype_as_container_get_fields():
|
|||
|
||||
def test_parent_container_get_fields():
|
||||
fields = Container._meta.graphql_type.get_fields()
|
||||
print [(f.creation_counter, f.name) for f in fields.values()]
|
||||
assert fields.keys() == ['field1', 'field2']
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user