mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-29 13:03:56 +03:00
Improved relay nodes and field copies
This commit is contained in:
parent
522f769cad
commit
51e97510c0
|
@ -4,7 +4,7 @@ from graphene import implements, relay, resolve_only_args
|
||||||
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
|
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
|
||||||
|
|
||||||
|
|
||||||
class Ship(graphene.ObjectType, relay.Node):
|
class Ship(relay.Node, graphene.ObjectType):
|
||||||
'''A ship in the Star Wars saga'''
|
'''A ship in the Star Wars saga'''
|
||||||
name = graphene.String(description='The name of the ship.')
|
name = graphene.String(description='The name of the ship.')
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ class Ship(graphene.ObjectType, relay.Node):
|
||||||
return get_ship(id)
|
return get_ship(id)
|
||||||
|
|
||||||
|
|
||||||
class Faction(graphene.ObjectType, relay.Node):
|
class Faction(relay.Node, graphene.ObjectType):
|
||||||
'''A faction in the Star Wars saga'''
|
'''A faction in the Star Wars saga'''
|
||||||
name = graphene.String(description='The name of the faction.')
|
name = graphene.String(description='The name of the faction.')
|
||||||
# ships = relay.ConnectionField(
|
# ships = relay.ConnectionField(
|
||||||
|
|
|
@ -39,14 +39,9 @@ class ClientIDMutationMeta(MutationMeta):
|
||||||
input_fields=input_local_fields,
|
input_fields=input_local_fields,
|
||||||
output_fields=cls._fields(bases, attrs, local_fields),
|
output_fields=cls._fields(bases, attrs, local_fields),
|
||||||
mutate_and_get_payload=cls.mutate_and_get_payload,
|
mutate_and_get_payload=cls.mutate_and_get_payload,
|
||||||
|
|
||||||
input_type_class=partial(GrapheneInputObjectType, graphene_type=cls),
|
|
||||||
input_field_class=InputField,
|
|
||||||
output_type_class=partial(GrapheneObjectType, graphene_type=cls),
|
|
||||||
field_class=Field,
|
|
||||||
)
|
)
|
||||||
cls._meta.graphql_type = field.type
|
cls._meta.graphql_type = field.type
|
||||||
cls.Field = partial(Field.copy_and_extend, field, type=None, _creation_counter=None)
|
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
|
||||||
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
|
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
|
||||||
return constructed
|
return constructed
|
||||||
|
|
||||||
|
|
|
@ -11,25 +11,30 @@ from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMet
|
||||||
class NodeMeta(ObjectTypeMeta):
|
class NodeMeta(ObjectTypeMeta):
|
||||||
|
|
||||||
def construct(cls, bases, attrs):
|
def construct(cls, bases, attrs):
|
||||||
if not cls.is_object_type():
|
is_object_type = cls.is_object_type()
|
||||||
cls.get_node = attrs.pop('get_node')
|
cls = super(NodeMeta, cls).construct(bases, attrs)
|
||||||
|
if not is_object_type:
|
||||||
|
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
|
||||||
|
assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__)
|
||||||
cls.id_resolver = attrs.pop('id_resolver', None)
|
cls.id_resolver = attrs.pop('id_resolver', None)
|
||||||
node_interface, node_field = node_definitions(
|
node_interface, node_field = node_definitions(
|
||||||
cls.get_node,
|
get_node_from_global_id,
|
||||||
id_resolver=cls.id_resolver,
|
id_resolver=cls.id_resolver,
|
||||||
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
|
||||||
field_class=Field,
|
|
||||||
)
|
)
|
||||||
cls._meta.graphql_type = node_interface
|
cls._meta.graphql_type = node_interface
|
||||||
cls.Field = partial(Field.copy_and_extend, node_field, type=None, _creation_counter=None)
|
cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None)
|
||||||
return super(NodeMeta, cls).construct(bases, attrs)
|
else:
|
||||||
|
# The interface provided by node_definitions is not an instance
|
||||||
|
# of GrapheneInterfaceType, so it will have no graphql_type,
|
||||||
|
# so will not trigger Node.implements
|
||||||
|
cls.implements(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
class Node(six.with_metaclass(NodeMeta, Interface)):
|
class Node(six.with_metaclass(NodeMeta, Interface)):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def require_get_node(cls):
|
def require_get_node(cls):
|
||||||
return cls == Node
|
return Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_global_id(cls, global_id):
|
def from_global_id(cls, global_id):
|
||||||
|
@ -44,7 +49,7 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
|
||||||
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
|
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_node(cls, global_id, context, info):
|
def get_node_from_global_id(cls, global_id, context, info):
|
||||||
try:
|
try:
|
||||||
_type, _id = cls.from_global_id(global_id)
|
_type, _id = cls.from_global_id(global_id)
|
||||||
except:
|
except:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from ...types import ObjectType, Schema, implements
|
||||||
from ...types.scalars import String
|
from ...types.scalars import String
|
||||||
|
|
||||||
|
|
||||||
class MyNode(ObjectType, Node):
|
class MyNode(Node, ObjectType):
|
||||||
|
|
||||||
name = String()
|
name = String()
|
||||||
|
|
||||||
|
@ -25,9 +25,7 @@ schema = Schema(query=RootQuery, types=[MyNode])
|
||||||
|
|
||||||
def test_node_no_get_node():
|
def test_node_no_get_node():
|
||||||
with pytest.raises(AssertionError) as excinfo:
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
class MyNode(ObjectType):
|
class MyNode(Node, ObjectType):
|
||||||
class Meta:
|
|
||||||
interfaces = [Node]
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||||
|
@ -35,9 +33,8 @@ def test_node_no_get_node():
|
||||||
|
|
||||||
def test_node_no_get_node_with_meta():
|
def test_node_no_get_node_with_meta():
|
||||||
with pytest.raises(AssertionError) as excinfo:
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
class MyNode(ObjectType):
|
class MyNode(Node, ObjectType):
|
||||||
class Meta:
|
pass
|
||||||
interfaces = [Node]
|
|
||||||
|
|
||||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ from ...types.scalars import String, Int
|
||||||
|
|
||||||
class CustomNode(Node):
|
class CustomNode(Node):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_node(id, context, info):
|
def get_node_from_global_id(id, context, info):
|
||||||
assert info.schema == schema
|
assert info.schema == schema
|
||||||
if id in user_data:
|
if id in user_data:
|
||||||
return user_data.get(id)
|
return user_data.get(id)
|
||||||
|
@ -15,11 +15,11 @@ class CustomNode(Node):
|
||||||
return photo_data.get(id)
|
return photo_data.get(id)
|
||||||
|
|
||||||
|
|
||||||
class User(ObjectType, CustomNode):
|
class User(CustomNode, ObjectType):
|
||||||
name = String()
|
name = String()
|
||||||
|
|
||||||
|
|
||||||
class Photo(ObjectType, CustomNode):
|
class Photo(CustomNode, ObjectType):
|
||||||
width = Int()
|
width = Int()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -116,20 +116,40 @@ class Field(AbstractField, GraphQLField, OrderedType):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args):
|
def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args):
|
||||||
|
if isinstance(field, Field):
|
||||||
|
type = type or field._type
|
||||||
|
resolver = resolver or field._resolver
|
||||||
|
source = source or field.source
|
||||||
|
name = name or field._name
|
||||||
|
required = required or field.required
|
||||||
|
_creation_counter = field.creation_counter if _creation_counter is False else None
|
||||||
|
attname = field.attname
|
||||||
|
parent = field.parent
|
||||||
|
else:
|
||||||
|
# If is a GraphQLField
|
||||||
|
type = type or field.type
|
||||||
|
resolver = resolver or field.resolver
|
||||||
|
source = None
|
||||||
|
name = field.name
|
||||||
|
required = None
|
||||||
|
_creation_counter = None
|
||||||
|
attname = name
|
||||||
|
parent = None
|
||||||
|
|
||||||
new_field = cls(
|
new_field = cls(
|
||||||
type=type or field._type,
|
type=type,
|
||||||
args=to_arguments(args, field.args),
|
args=to_arguments(args, field.args),
|
||||||
resolver=field._resolver,
|
resolver=resolver,
|
||||||
source=source or getattr(field, 'source', None),
|
source=source,
|
||||||
deprecation_reason=field.deprecation_reason,
|
deprecation_reason=field.deprecation_reason,
|
||||||
name=field._name,
|
name=name,
|
||||||
required=required or getattr(field, 'required', False),
|
required=required,
|
||||||
description=field.description,
|
description=field.description,
|
||||||
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
|
_creation_counter=_creation_counter,
|
||||||
**extra_args
|
**extra_args
|
||||||
)
|
)
|
||||||
new_field.attname = field.attname
|
new_field.attname = attname
|
||||||
new_field.parent = field.parent
|
new_field.parent = parent
|
||||||
return new_field
|
return new_field
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -168,14 +188,30 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False):
|
def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False):
|
||||||
|
if isinstance(field, Field):
|
||||||
|
type = type or field._type
|
||||||
|
name = name or field._name
|
||||||
|
required = required or field.required
|
||||||
|
_creation_counter = field.creation_counter if _creation_counter is False else None
|
||||||
|
attname = field.attname
|
||||||
|
parent = field.parent
|
||||||
|
else:
|
||||||
|
# If is a GraphQLField
|
||||||
|
type = type or field.type
|
||||||
|
name = field.name
|
||||||
|
required = None
|
||||||
|
_creation_counter = None
|
||||||
|
attname = None
|
||||||
|
parent = None
|
||||||
|
|
||||||
new_field = cls(
|
new_field = cls(
|
||||||
type=type or field._type,
|
type=type,
|
||||||
name=name or field._name,
|
name=name,
|
||||||
required=required or field.required,
|
required=required,
|
||||||
default_value=default_value or field.default_value,
|
default_value=default_value or field.default_value,
|
||||||
description=description or field.description,
|
description=description or field.description,
|
||||||
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None,
|
_creation_counter=_creation_counter,
|
||||||
)
|
)
|
||||||
new_field.attname = field.attname
|
new_field.attname = attname
|
||||||
new_field.parent = field.parent
|
new_field.parent = parent
|
||||||
return new_field
|
return new_field
|
||||||
|
|
|
@ -40,7 +40,10 @@ class Interface(six.with_metaclass(InterfaceTypeMeta)):
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
raise Exception("An interface cannot be intitialized")
|
from .objecttype import ObjectType
|
||||||
|
if not isinstance(self, ObjectType):
|
||||||
|
raise Exception("An interface cannot be intitialized")
|
||||||
|
super(Interface, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def implements(cls, object_type):
|
def implements(cls, object_type):
|
||||||
|
|
|
@ -71,6 +71,7 @@ class ObjectTypeMeta(InterfaceTypeMeta):
|
||||||
cls._meta.interfaces,
|
cls._meta.interfaces,
|
||||||
cls.get_interfaces(bases),
|
cls.get_interfaces(bases),
|
||||||
)))
|
)))
|
||||||
|
|
||||||
local_fields = cls._extract_local_fields(attrs)
|
local_fields = cls._extract_local_fields(attrs)
|
||||||
if not cls._meta.graphql_type:
|
if not cls._meta.graphql_type:
|
||||||
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)
|
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)
|
||||||
|
|
|
@ -93,7 +93,6 @@ def test_objecttype_as_container_get_fields():
|
||||||
|
|
||||||
def test_parent_container_get_fields():
|
def test_parent_container_get_fields():
|
||||||
fields = Container._meta.graphql_type.get_fields()
|
fields = Container._meta.graphql_type.get_fields()
|
||||||
print [(f.creation_counter, f.name) for f in fields.values()]
|
|
||||||
assert fields.keys() == ['field1', 'field2']
|
assert fields.keys() == ['field1', 'field2']
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user