Merge pull request #61 from graphql-python/features/classtypes

Refactor ObjectTypes in ClassTypes. Fixed #30
This commit is contained in:
Syrus Akbary 2015-12-02 23:59:51 -08:00
commit 699aebec33
34 changed files with 922 additions and 655 deletions

View File

@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
autoflake ./ -r --remove-unused-variables --remove-all-unused-imports --in-place autoflake ./examples/ ./graphene/ -r --remove-unused-variables --remove-all-unused-imports --in-place
autopep8 ./ -r --in-place --experimental --aggressive --max-line-length 120 autopep8 ./examples/ ./graphene/ -r --in-place --experimental --aggressive --max-line-length 120
isort -rc . isort -rc ./examples/ ./graphene/

View File

@ -8,11 +8,15 @@ from graphene.core.schema import (
Schema Schema
) )
from graphene.core.types import ( from graphene.core.classtypes import (
ObjectType, ObjectType,
InputObjectType, InputObjectType,
Interface, Interface,
Mutation, Mutation,
Scalar
)
from graphene.core.types import (
BaseType, BaseType,
LazyType, LazyType,
Argument, Argument,
@ -59,6 +63,7 @@ __all__ = [
'InputObjectType', 'InputObjectType',
'Interface', 'Interface',
'Mutation', 'Mutation',
'Scalar',
'Field', 'Field',
'InputField', 'InputField',
'StringField', 'StringField',

View File

@ -1,7 +1,6 @@
from graphene.contrib.django.types import ( from graphene.contrib.django.types import (
DjangoConnection, DjangoConnection,
DjangoObjectType, DjangoObjectType,
DjangoInterface,
DjangoNode DjangoNode
) )
from graphene.contrib.django.fields import ( from graphene.contrib.django.fields import (
@ -9,5 +8,5 @@ from graphene.contrib.django.fields import (
DjangoModelField DjangoModelField
) )
__all__ = ['DjangoObjectType', 'DjangoInterface', 'DjangoNode', __all__ = ['DjangoObjectType', 'DjangoNode', 'DjangoConnection',
'DjangoConnection', 'DjangoConnectionField', 'DjangoModelField'] 'DjangoConnectionField', 'DjangoModelField']

View File

@ -1,24 +1,15 @@
import inspect from ...core.classtypes.objecttype import ObjectTypeOptions
from django.db import models
from ...core.options import Options
from ...relay.types import Node from ...relay.types import Node
from ...relay.utils import is_node from ...relay.utils import is_node
VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') VALID_ATTRS = ('model', 'only_fields', 'exclude_fields')
def is_base(cls): class DjangoOptions(ObjectTypeOptions):
from graphene.contrib.django.types import DjangoObjectType
return DjangoObjectType in cls.__bases__
class DjangoOptions(Options):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.model = None
super(DjangoOptions, self).__init__(*args, **kwargs) super(DjangoOptions, self).__init__(*args, **kwargs)
self.model = None
self.valid_attrs += VALID_ATTRS self.valid_attrs += VALID_ATTRS
self.only_fields = None self.only_fields = None
self.exclude_fields = [] self.exclude_fields = []
@ -28,11 +19,3 @@ class DjangoOptions(Options):
if is_node(cls): if is_node(cls):
self.exclude_fields = list(self.exclude_fields) + ['id'] self.exclude_fields = list(self.exclude_fields) + ['id']
self.interfaces.append(Node) self.interfaces.append(Node)
if not is_node(cls) and not is_base(cls):
return
if not self.model:
raise Exception(
'Django ObjectType %s must have a model in the Meta class attr' %
cls)
elif not inspect.isclass(self.model) or not issubclass(self.model, models.Model):
raise Exception('Provided model in %s is not a Django model' % cls)

View File

@ -1,9 +1,8 @@
from graphql.core.type import GraphQLInterfaceType, GraphQLObjectType from graphql.core.type import GraphQLObjectType
from mock import patch from mock import patch
from pytest import raises
from graphene import Schema from graphene import Schema
from graphene.contrib.django.types import DjangoInterface, DjangoNode from graphene.contrib.django.types import DjangoNode, DjangoObjectType
from graphene.core.fields import Field from graphene.core.fields import Field
from graphene.core.types.scalars import Int from graphene.core.types.scalars import Int
from graphene.relay.fields import GlobalIDField from graphene.relay.fields import GlobalIDField
@ -14,7 +13,8 @@ from .models import Article, Reporter
schema = Schema() schema = Schema()
class Character(DjangoInterface): @schema.register
class Character(DjangoObjectType):
'''Character description''' '''Character description'''
class Meta: class Meta:
model = Reporter model = Reporter
@ -31,7 +31,7 @@ class Human(DjangoNode):
def test_django_interface(): def test_django_interface():
assert DjangoNode._meta.is_interface is True assert DjangoNode._meta.interface is True
@patch('graphene.contrib.django.tests.models.Article.objects.get', return_value=Article(id=1)) @patch('graphene.contrib.django.tests.models.Article.objects.get', return_value=Article(id=1))
@ -41,17 +41,6 @@ def test_django_get_node(get):
assert human.id == 1 assert human.id == 1
def test_pseudo_interface_registered():
object_type = schema.T(Character)
assert Character._meta.is_interface is True
assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.model == Reporter
assert_equal_lists(
object_type.get_fields().keys(),
['articles', 'firstName', 'lastName', 'email', 'pets', 'id']
)
def test_djangonode_idfield(): def test_djangonode_idfield():
idfield = DjangoNode._meta.fields_map['id'] idfield = DjangoNode._meta.fields_map['id']
assert isinstance(idfield, GlobalIDField) assert isinstance(idfield, GlobalIDField)
@ -68,32 +57,21 @@ def test_node_replacedfield():
assert schema.T(idfield).type == schema.T(Int()) assert schema.T(idfield).type == schema.T(Int())
def test_interface_resolve_type(): def test_objecttype_init_none():
resolve_type = Character._resolve_type(schema, Human())
assert isinstance(resolve_type, GraphQLObjectType)
def test_interface_objecttype_init_none():
h = Human() h = Human()
assert h._root is None assert h._root is None
def test_interface_objecttype_init_good(): def test_objecttype_init_good():
instance = Article() instance = Article()
h = Human(instance) h = Human(instance)
assert h._root == instance assert h._root == instance
def test_interface_objecttype_init_unexpected():
with raises(AssertionError) as excinfo:
Human(object())
assert str(excinfo.value) == "Human received a non-compatible instance (object) when expecting Article"
def test_object_type(): def test_object_type():
object_type = schema.T(Human) object_type = schema.T(Human)
Human._meta.fields_map Human._meta.fields_map
assert Human._meta.is_interface is False assert Human._meta.interface is False
assert isinstance(object_type, GraphQLObjectType) assert isinstance(object_type, GraphQLObjectType)
assert_equal_lists( assert_equal_lists(
object_type.get_fields().keys(), object_type.get_fields().keys(),
@ -103,5 +81,5 @@ def test_object_type():
def test_node_notinterface(): def test_node_notinterface():
assert Human._meta.is_interface is False assert Human._meta.interface is False
assert DjangoNode in Human._meta.interfaces assert DjangoNode in Human._meta.interfaces

View File

@ -1,22 +1,19 @@
import six import inspect
from ...core.types import BaseObjectType, ObjectTypeMeta import six
from ...relay.fields import GlobalIDField from django.db import models
from ...relay.types import BaseNode, Connection
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
from ...relay.types import Connection, Node, NodeMeta
from .converter import convert_django_field from .converter import convert_django_field
from .options import DjangoOptions from .options import DjangoOptions
from .utils import get_reverse_fields, maybe_queryset from .utils import get_reverse_fields, maybe_queryset
class DjangoObjectTypeMeta(ObjectTypeMeta): class DjangoObjectTypeMeta(ObjectTypeMeta):
options_cls = DjangoOptions options_class = DjangoOptions
def is_interface(cls, parents): def construct_fields(cls):
return DjangoInterface in parents
def add_extra_fields(cls):
if not cls._meta.model:
return
only_fields = cls._meta.only_fields only_fields = cls._meta.only_fields
reverse_fields = get_reverse_fields(cls._meta.model) reverse_fields = get_reverse_fields(cls._meta.model)
all_fields = sorted(list(cls._meta.model._meta.fields) + all_fields = sorted(list(cls._meta.model._meta.fields) +
@ -35,8 +32,24 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
converted_field = convert_django_field(field) converted_field = convert_django_field(field)
cls.add_to_class(field.name, converted_field) cls.add_to_class(field.name, converted_field)
def construct(cls, *args, **kwargs):
cls = super(DjangoObjectTypeMeta, cls).construct(*args, **kwargs)
if not cls._meta.abstract:
if not cls._meta.model:
raise Exception(
'Django ObjectType %s must have a model in the Meta class attr' %
cls)
elif not inspect.isclass(cls._meta.model) or not issubclass(cls._meta.model, models.Model):
raise Exception('Provided model in %s is not a Django model' % cls)
class InstanceObjectType(BaseObjectType): cls.construct_fields()
return cls
class InstanceObjectType(ObjectType):
class Meta:
abstract = True
def __init__(self, _root=None): def __init__(self, _root=None):
if _root: if _root:
@ -63,12 +76,9 @@ class InstanceObjectType(BaseObjectType):
class DjangoObjectType(six.with_metaclass( class DjangoObjectType(six.with_metaclass(
DjangoObjectTypeMeta, InstanceObjectType)): DjangoObjectTypeMeta, InstanceObjectType)):
pass
class Meta:
class DjangoInterface(six.with_metaclass( abstract = True
DjangoObjectTypeMeta, InstanceObjectType)):
pass
class DjangoConnection(Connection): class DjangoConnection(Connection):
@ -79,8 +89,21 @@ class DjangoConnection(Connection):
return super(DjangoConnection, cls).from_list(iterable, *args, **kwargs) return super(DjangoConnection, cls).from_list(iterable, *args, **kwargs)
class DjangoNode(BaseNode, DjangoInterface): class DjangoNodeMeta(DjangoObjectTypeMeta, NodeMeta):
id = GlobalIDField() pass
class NodeInstance(Node, InstanceObjectType):
class Meta:
abstract = True
class DjangoNode(six.with_metaclass(
DjangoNodeMeta, NodeInstance)):
class Meta:
abstract = True
@classmethod @classmethod
def get_node(cls, id, info=None): def get_node(cls, id, info=None):

View File

@ -0,0 +1,16 @@
from .inputobjecttype import InputObjectType
from .interface import Interface
from .mutation import Mutation
from .objecttype import ObjectType
from .options import Options
from .scalar import Scalar
from .uniontype import UnionType
__all__ = [
'InputObjectType',
'Interface',
'Mutation',
'ObjectType',
'Options',
'Scalar',
'UnionType']

View File

@ -0,0 +1,133 @@
import copy
import inspect
from collections import OrderedDict
import six
from ..exceptions import SkipField
from .options import Options
class ClassTypeMeta(type):
options_class = Options
def __new__(mcs, name, bases, attrs):
super_new = super(ClassTypeMeta, mcs).__new__
module = attrs.pop('__module__', None)
doc = attrs.pop('__doc__', None)
new_class = super_new(mcs, name, bases, {
'__module__': module,
'__doc__': doc
})
attr_meta = attrs.pop('Meta', None)
if not attr_meta:
meta = getattr(new_class, 'Meta', None)
else:
meta = attr_meta
new_class.add_to_class('_meta', new_class.get_options(meta))
return mcs.construct(new_class, bases, attrs)
def get_options(cls, meta):
return cls.options_class(meta)
def add_to_class(cls, name, value):
# We should call the contribute_to_class method only if it's bound
if not inspect.isclass(value) and hasattr(
value, 'contribute_to_class'):
value.contribute_to_class(cls, name)
else:
setattr(cls, name, value)
def construct(cls, bases, attrs):
# Add all attributes to the class.
for obj_name, obj in attrs.items():
cls.add_to_class(obj_name, obj)
if not cls._meta.abstract:
from ..types import List, NonNull
setattr(cls, 'NonNull', NonNull(cls))
setattr(cls, 'List', List(cls))
return cls
class ClassType(six.with_metaclass(ClassTypeMeta)):
class Meta:
abstract = True
@classmethod
def internal_type(cls, schema):
raise NotImplementedError("Function internal_type not implemented in type {}".format(cls))
class FieldsOptions(Options):
def __init__(self, *args, **kwargs):
super(FieldsOptions, self).__init__(*args, **kwargs)
self.local_fields = []
def add_field(self, field):
self.local_fields.append(field)
@property
def fields(self):
return sorted(self.local_fields)
@property
def fields_map(self):
return OrderedDict([(f.attname, f) for f in self.fields])
class FieldsClassTypeMeta(ClassTypeMeta):
options_class = FieldsOptions
def extend_fields(cls, bases):
new_fields = cls._meta.local_fields
field_names = {f.name: f for f in new_fields}
for base in bases:
if not isinstance(base, FieldsClassTypeMeta):
continue
parent_fields = base._meta.local_fields
for field in parent_fields:
if field.name in field_names and field.type.__class__ != field_names[
field.name].type.__class__:
raise Exception(
'Local field %r in class %r (%r) clashes '
'with field with similar name from '
'Interface %s (%r)' % (
field.name,
cls.__name__,
field.__class__,
base.__name__,
field_names[field.name].__class__)
)
new_field = copy.copy(field)
cls.add_to_class(field.attname, new_field)
def construct(cls, bases, attrs):
cls = super(FieldsClassTypeMeta, cls).construct(bases, attrs)
cls.extend_fields(bases)
return cls
class FieldsClassType(six.with_metaclass(FieldsClassTypeMeta, ClassType)):
class Meta:
abstract = True
@classmethod
def fields_internal_types(cls, schema):
fields = []
for field in cls._meta.fields:
try:
fields.append((field.name, schema.T(field)))
except SkipField:
continue
return OrderedDict(fields)

View File

@ -0,0 +1,25 @@
from functools import partial
from graphql.core.type import GraphQLInputObjectType
from .base import FieldsClassType
class InputObjectType(FieldsClassType):
class Meta:
abstract = True
def __init__(self, *args, **kwargs):
raise Exception("An InputObjectType cannot be initialized")
@classmethod
def internal_type(cls, schema):
if cls._meta.abstract:
raise Exception("Abstract InputObjectTypes don't have a specific type.")
return GraphQLInputObjectType(
cls._meta.type_name,
description=cls._meta.description,
fields=partial(cls.fields_internal_types, schema),
)

View File

@ -0,0 +1,53 @@
from functools import partial
import six
from graphql.core.type import GraphQLInterfaceType
from .base import FieldsClassTypeMeta
from .objecttype import ObjectType, ObjectTypeMeta
class InterfaceMeta(ObjectTypeMeta):
def construct(cls, bases, attrs):
if cls._meta.abstract or Interface in bases:
# Return Interface type
cls = FieldsClassTypeMeta.construct(cls, bases, attrs)
setattr(cls._meta, 'interface', True)
return cls
else:
# Return ObjectType class with all the inherited interfaces
cls = super(InterfaceMeta, cls).construct(bases, attrs)
for interface in bases:
is_interface = issubclass(interface, Interface) and getattr(interface._meta, 'interface', False)
if not is_interface:
continue
cls._meta.interfaces.append(interface)
return cls
class Interface(six.with_metaclass(InterfaceMeta, ObjectType)):
class Meta:
abstract = True
def __init__(self, *args, **kwargs):
if self._meta.interface:
raise Exception("An interface cannot be initialized")
return super(Interface, self).__init__(*args, **kwargs)
@classmethod
def _resolve_type(cls, schema, instance, *args):
return schema.T(instance.__class__)
@classmethod
def internal_type(cls, schema):
if not cls._meta.interface:
return super(Interface, cls).internal_type(schema)
return GraphQLInterfaceType(
cls._meta.type_name,
description=cls._meta.description,
resolve_type=partial(cls._resolve_type, schema),
fields=partial(cls.fields_internal_types, schema)
)

View File

@ -0,0 +1,32 @@
import six
from .objecttype import ObjectType, ObjectTypeMeta
class MutationMeta(ObjectTypeMeta):
def construct(cls, bases, attrs):
input_class = attrs.pop('Input', None)
if input_class:
items = dict(vars(input_class))
items.pop('__dict__', None)
items.pop('__doc__', None)
items.pop('__module__', None)
items.pop('__weakref__', None)
cls.add_to_class('arguments', cls.construct_arguments(items))
cls = super(MutationMeta, cls).construct(bases, attrs)
return cls
def construct_arguments(cls, items):
from ..types.argument import ArgumentsGroup
return ArgumentsGroup(**items)
class Mutation(six.with_metaclass(MutationMeta, ObjectType)):
class Meta:
abstract = True
@classmethod
def get_arguments(cls):
return cls.arguments

View File

@ -0,0 +1,103 @@
from functools import partial
import six
from graphql.core.type import GraphQLObjectType
from graphene import signals
from .base import FieldsClassType, FieldsClassTypeMeta, FieldsOptions
from .uniontype import UnionType
def is_objecttype(cls):
if not issubclass(cls, ObjectType):
return False
return not cls._meta.interface
class ObjectTypeOptions(FieldsOptions):
def __init__(self, *args, **kwargs):
super(ObjectTypeOptions, self).__init__(*args, **kwargs)
self.interface = False
self.interfaces = []
class ObjectTypeMeta(FieldsClassTypeMeta):
def construct(cls, bases, attrs):
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)
if not cls._meta.abstract:
union_types = list(filter(is_objecttype, bases))
if len(union_types) > 1:
meta_attrs = dict(cls._meta.original_attrs, types=union_types)
Meta = type('Meta', (object, ), meta_attrs)
attrs['Meta'] = Meta
attrs['__module__'] = cls.__module__
attrs['__doc__'] = cls.__doc__
return type(cls.__name__, (UnionType, ), attrs)
return cls
options_class = ObjectTypeOptions
class ObjectType(six.with_metaclass(ObjectTypeMeta, FieldsClassType)):
class Meta:
abstract = True
def __init__(self, *args, **kwargs):
signals.pre_init.send(self.__class__, args=args, kwargs=kwargs)
self._root = kwargs.pop('_root', None)
args_len = len(args)
fields = self._meta.fields
if args_len > len(fields):
# Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields")
fields_iter = iter(fields)
if not kwargs:
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
else:
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
kwargs.pop(field.attname, None)
for field in fields_iter:
try:
val = kwargs.pop(field.attname)
setattr(self, field.attname, val)
except KeyError:
pass
if kwargs:
for prop in list(kwargs):
try:
if isinstance(getattr(self.__class__, prop), property):
setattr(self, prop, kwargs.pop(prop))
except AttributeError:
pass
if kwargs:
raise TypeError(
"'%s' is an invalid keyword argument for this function" %
list(kwargs)[0])
signals.post_init.send(self.__class__, instance=self)
@classmethod
def internal_type(cls, schema):
if cls._meta.abstract:
raise Exception("Abstract ObjectTypes don't have a specific type.")
return GraphQLObjectType(
cls._meta.type_name,
description=cls._meta.description,
interfaces=list(map(schema.T, cls._meta.interfaces)),
fields=partial(cls.fields_internal_types, schema),
is_type_of=getattr(cls, 'is_type_of', None)
)
@classmethod
def wrap(cls, instance, args, info):
return cls(_root=instance)

View File

@ -1,24 +1,11 @@
from collections import OrderedDict
from ..utils import cached_property
DEFAULT_NAMES = ('description', 'name', 'is_interface', 'is_mutation',
'type_name', 'interfaces', 'abstract')
class Options(object): class Options(object):
def __init__(self, meta=None): def __init__(self, meta=None, **defaults):
self.meta = meta self.meta = meta
self.local_fields = []
self.is_interface = False
self.is_mutation = False
self.is_union = False
self.abstract = False self.abstract = False
self.interfaces = [] for name, value in defaults.items():
self.parents = [] setattr(self, name, value)
self.types = [] self.valid_attrs = list(defaults.keys()) + ['type_name', 'description', 'abstract']
self.valid_attrs = DEFAULT_NAMES
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
cls._meta = self cls._meta = self
@ -59,14 +46,3 @@ class Options(object):
meta_attrs.keys())) meta_attrs.keys()))
del self.meta del self.meta
def add_field(self, field):
self.local_fields.append(field)
@cached_property
def fields(self):
return sorted(self.local_fields)
@cached_property
def fields_map(self):
return OrderedDict([(f.attname, f) for f in self.fields])

View File

@ -0,0 +1,21 @@
from graphql.core.type import GraphQLScalarType
from ..types.base import MountedType
from .base import ClassType
class Scalar(ClassType, MountedType):
@classmethod
def internal_type(cls, schema):
serialize = getattr(cls, 'serialize')
parse_literal = getattr(cls, 'parse_literal')
parse_value = getattr(cls, 'parse_value')
return GraphQLScalarType(
name=cls._meta.type_name,
description=cls._meta.description,
serialize=serialize,
parse_value=parse_value,
parse_literal=parse_literal
)

View File

@ -1,11 +1,9 @@
from py.test import raises from py.test import raises
from graphene.core.fields import Field from graphene.core.classtypes import Options
from graphene.core.options import Options
class Meta: class Meta:
is_interface = True
type_name = 'Character' type_name = 'Character'
@ -13,19 +11,6 @@ class InvalidMeta:
other_value = True other_value = True
def test_field_added_in_meta():
opt = Options(Meta)
class ObjectType(object):
pass
opt.contribute_to_class(ObjectType, '_meta')
f = Field(None)
f.attname = 'string_field'
opt.add_field(f)
assert f in opt.fields
def test_options_contribute(): def test_options_contribute():
opt = Options(Meta) opt = Options(Meta)

View File

@ -0,0 +1,67 @@
from ...schema import Schema
from ...types import Field, List, NonNull, String
from ..base import ClassType, FieldsClassType
def test_classtype_basic():
class Character(ClassType):
'''Character description'''
assert Character._meta.type_name == 'Character'
assert Character._meta.description == 'Character description'
def test_classtype_advanced():
class Character(ClassType):
class Meta:
type_name = 'OtherCharacter'
description = 'OtherCharacter description'
assert Character._meta.type_name == 'OtherCharacter'
assert Character._meta.description == 'OtherCharacter description'
def test_classtype_definition_list():
class Character(ClassType):
'''Character description'''
assert isinstance(Character.List, List)
assert Character.List.of_type == Character
def test_classtype_definition_nonnull():
class Character(ClassType):
'''Character description'''
assert isinstance(Character.NonNull, NonNull)
assert Character.NonNull.of_type == Character
def test_fieldsclasstype():
f = Field(String())
class Character(FieldsClassType):
field_name = f
assert Character._meta.fields == [f]
def test_fieldsclasstype_fieldtype():
f = Field(String())
class Character(FieldsClassType):
field_name = f
schema = Schema(query=Character)
assert Character.fields_internal_types(schema)['fieldName'] == schema.T(f)
assert Character._meta.fields_map['field_name'] == f
def test_fieldsclasstype_inheritfields():
name_field = Field(String())
last_name_field = Field(String())
class Fields1(FieldsClassType):
name = name_field
class Fields2(Fields1):
last_name = last_name_field
assert list(Fields2._meta.fields_map.keys()) == ['name', 'last_name']

View File

@ -0,0 +1,21 @@
from graphql.core.type import GraphQLInputObjectType
from graphene.core.schema import Schema
from graphene.core.types import String
from ..inputobjecttype import InputObjectType
def test_inputobjecttype():
class InputCharacter(InputObjectType):
'''InputCharacter description'''
name = String()
schema = Schema()
object_type = schema.T(InputCharacter)
assert isinstance(object_type, GraphQLInputObjectType)
assert InputCharacter._meta.type_name == 'InputCharacter'
assert object_type.description == 'InputCharacter description'
assert list(object_type.get_fields().keys()) == ['name']

View File

@ -0,0 +1,86 @@
from graphql.core.type import GraphQLInterfaceType, GraphQLObjectType
from py.test import raises
from graphene.core.schema import Schema
from graphene.core.types import String
from ..interface import Interface
from ..objecttype import ObjectType
def test_interface():
class Character(Interface):
'''Character description'''
name = String()
schema = Schema()
object_type = schema.T(Character)
assert issubclass(Character, Interface)
assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.interface
assert Character._meta.type_name == 'Character'
assert object_type.description == 'Character description'
assert list(object_type.get_fields().keys()) == ['name']
def test_interface_cannot_initialize():
class Character(Interface):
pass
with raises(Exception) as excinfo:
Character()
assert 'An interface cannot be initialized' == str(excinfo.value)
def test_interface_inheritance_abstract():
class Character(Interface):
pass
class ShouldBeInterface(Character):
class Meta:
abstract = True
class ShouldBeObjectType(ShouldBeInterface):
pass
assert ShouldBeInterface._meta.interface
assert not ShouldBeObjectType._meta.interface
assert issubclass(ShouldBeObjectType, ObjectType)
def test_interface_inheritance():
class Character(Interface):
pass
class GeneralInterface(Interface):
pass
class ShouldBeObjectType(GeneralInterface, Character):
pass
schema = Schema()
assert Character._meta.interface
assert not ShouldBeObjectType._meta.interface
assert issubclass(ShouldBeObjectType, ObjectType)
assert Character in ShouldBeObjectType._meta.interfaces
assert GeneralInterface in ShouldBeObjectType._meta.interfaces
assert isinstance(schema.T(Character), GraphQLInterfaceType)
assert isinstance(schema.T(ShouldBeObjectType), GraphQLObjectType)
def test_interface_inheritance_non_objects():
class CommonClass(object):
common_attr = True
class Character(CommonClass, Interface):
pass
class ShouldBeObjectType(Character):
pass
assert Character._meta.interface
assert Character.common_attr
assert ShouldBeObjectType.common_attr

View File

@ -0,0 +1,27 @@
from graphql.core.type import GraphQLObjectType
from graphene.core.schema import Schema
from graphene.core.types import String
from ...types.argument import ArgumentsGroup
from ..mutation import Mutation
def test_mutation():
class MyMutation(Mutation):
'''MyMutation description'''
class Input:
arg_name = String()
name = String()
schema = Schema()
object_type = schema.T(MyMutation)
assert MyMutation._meta.type_name == 'MyMutation'
assert isinstance(object_type, GraphQLObjectType)
assert object_type.description == 'MyMutation description'
assert list(object_type.get_fields().keys()) == ['name']
assert MyMutation._meta.fields_map['name'].object_type == MyMutation
assert isinstance(MyMutation.arguments, ArgumentsGroup)
assert 'argName' in MyMutation.arguments

View File

@ -0,0 +1,89 @@
from graphql.core.type import GraphQLObjectType
from py.test import raises
from graphene.core.schema import Schema
from graphene.core.types import String
from ..objecttype import ObjectType
from ..uniontype import UnionType
def test_object_type():
class Human(ObjectType):
'''Human description'''
name = String()
friends = String()
schema = Schema()
object_type = schema.T(Human)
assert Human._meta.type_name == 'Human'
assert isinstance(object_type, GraphQLObjectType)
assert object_type.description == 'Human description'
assert list(object_type.get_fields().keys()) == ['name', 'friends']
assert Human._meta.fields_map['name'].object_type == Human
def test_object_type_container():
class Human(ObjectType):
name = String()
friends = String()
h = Human(name='My name')
assert h.name == 'My name'
def test_object_type_set_properties():
class Human(ObjectType):
name = String()
friends = String()
@property
def readonly_prop(self):
return 'readonly'
@property
def write_prop(self):
return self._write_prop
@write_prop.setter
def write_prop(self, value):
self._write_prop = value
h = Human(readonly_prop='custom', write_prop='custom')
assert h.readonly_prop == 'readonly'
assert h.write_prop == 'custom'
def test_object_type_container_invalid_kwarg():
class Human(ObjectType):
name = String()
with raises(TypeError):
Human(invalid='My name')
def test_object_type_container_too_many_args():
class Human(ObjectType):
name = String()
with raises(IndexError):
Human('Peter', 'No friends :(', None)
def test_object_type_union():
class Human(ObjectType):
name = String()
class Pet(ObjectType):
name = String()
class Thing(Human, Pet):
'''Thing union description'''
my_attr = True
assert issubclass(Thing, UnionType)
assert Thing._meta.types == [Human, Pet]
assert Thing._meta.type_name == 'Thing'
assert Thing._meta.description == 'Thing union description'
assert Thing.my_attr

View File

@ -0,0 +1,32 @@
from graphql.core.type import GraphQLScalarType
from ...schema import Schema
from ..scalar import Scalar
def test_custom_scalar():
import datetime
from graphql.core.language import ast
class DateTimeScalar(Scalar):
'''DateTimeScalar Documentation'''
@staticmethod
def serialize(dt):
return dt.isoformat()
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
return datetime.datetime.strptime(
node.value, "%Y-%m-%dT%H:%M:%S.%f")
@staticmethod
def parse_value(value):
return datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f")
schema = Schema()
scalar_type = schema.T(DateTimeScalar)
assert isinstance(scalar_type, GraphQLScalarType)
assert scalar_type.name == 'DateTimeScalar'
assert scalar_type.description == 'DateTimeScalar Documentation'

View File

@ -0,0 +1,28 @@
from graphql.core.type import GraphQLUnionType
from graphene.core.schema import Schema
from graphene.core.types import String
from ..objecttype import ObjectType
from ..uniontype import UnionType
def test_uniontype():
class Human(ObjectType):
name = String()
class Pet(ObjectType):
name = String()
class Thing(UnionType):
'''Thing union description'''
class Meta:
types = [Human, Pet]
schema = Schema()
object_type = schema.T(Thing)
assert isinstance(object_type, GraphQLUnionType)
assert Thing._meta.type_name == 'Thing'
assert object_type.description == 'Thing union description'
assert object_type.get_possible_types() == [schema.T(Human), schema.T(Pet)]

View File

@ -0,0 +1,40 @@
import six
from graphql.core.type import GraphQLUnionType
from .base import FieldsClassType, FieldsClassTypeMeta, FieldsOptions
class UnionTypeOptions(FieldsOptions):
def __init__(self, *args, **kwargs):
super(UnionTypeOptions, self).__init__(*args, **kwargs)
self.types = []
class UnionTypeMeta(FieldsClassTypeMeta):
options_class = UnionTypeOptions
def get_options(cls, meta):
return cls.options_class(meta, types=[])
class UnionType(six.with_metaclass(UnionTypeMeta, FieldsClassType)):
class Meta:
abstract = True
@classmethod
def _resolve_type(cls, schema, instance, *args):
return schema.T(instance.__class__)
@classmethod
def internal_type(cls, schema):
if cls._meta.abstract:
raise Exception("Abstract ObjectTypes don't have a specific type.")
return GraphQLUnionType(
cls._meta.type_name,
types=list(map(schema.T, cls._meta.types)),
resolve_type=cls._resolve_type,
description=cls._meta.description,
)

View File

@ -10,8 +10,8 @@ from graphql.core.utils.schema_printer import print_schema
from graphene import signals from graphene import signals
from .classtypes.base import ClassType
from .types.base import BaseType from .types.base import BaseType
from .types.objecttype import BaseObjectType
class GraphQLSchema(_GraphQLSchema): class GraphQLSchema(_GraphQLSchema):
@ -42,13 +42,13 @@ class Schema(object):
if not object_type: if not object_type:
return return
if inspect.isclass(object_type) and issubclass( if inspect.isclass(object_type) and issubclass(
object_type, BaseType) or isinstance( object_type, (BaseType, ClassType)) or isinstance(
object_type, BaseType): object_type, BaseType):
if object_type not in self._types: if object_type not in self._types:
internal_type = object_type.internal_type(self) internal_type = object_type.internal_type(self)
self._types[object_type] = internal_type self._types[object_type] = internal_type
is_objecttype = inspect.isclass( is_objecttype = inspect.isclass(
object_type) and issubclass(object_type, BaseObjectType) object_type) and issubclass(object_type, ClassType)
if is_objecttype: if is_objecttype:
self.register(object_type) self.register(object_type)
return self._types[object_type] return self._types[object_type]
@ -90,7 +90,7 @@ class Schema(object):
if name: if name:
objecttype = self._types_names.get(name, None) objecttype = self._types_names.get(name, None)
if objecttype and inspect.isclass( if objecttype and inspect.isclass(
objecttype) and issubclass(objecttype, BaseObjectType): objecttype) and issubclass(objecttype, ClassType):
return objecttype return objecttype
def __str__(self): def __str__(self):

View File

@ -1,7 +1,9 @@
from .base import BaseType, LazyType, OrderedType from .base import BaseType, LazyType, OrderedType
from .argument import Argument, ArgumentsGroup, to_arguments from .argument import Argument, ArgumentsGroup, to_arguments
from .definitions import List, NonNull from .definitions import List, NonNull
from .objecttype import ObjectTypeMeta, BaseObjectType, Interface, ObjectType, Mutation, InputObjectType # Compatibility import
from .objecttype import Interface, ObjectType, Mutation, InputObjectType
from .scalars import String, ID, Boolean, Int, Float, Scalar from .scalars import String, ID, Boolean, Int, Float, Scalar
from .field import Field, InputField from .field import Field, InputField
@ -17,8 +19,6 @@ __all__ = [
'Field', 'Field',
'InputField', 'InputField',
'Interface', 'Interface',
'BaseObjectType',
'ObjectTypeMeta',
'ObjectType', 'ObjectType',
'Mutation', 'Mutation',
'InputObjectType', 'InputObjectType',

View File

@ -104,11 +104,12 @@ class ArgumentType(MirroredType):
class FieldType(MirroredType): class FieldType(MirroredType):
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
from ..types import BaseObjectType, InputObjectType from ..classtypes.base import FieldsClassType
if issubclass(cls, InputObjectType): from ..classtypes.inputobjecttype import InputObjectType
if issubclass(cls, (InputObjectType)):
inputfield = self.as_inputfield() inputfield = self.as_inputfield()
return inputfield.contribute_to_class(cls, name) return inputfield.contribute_to_class(cls, name)
elif issubclass(cls, BaseObjectType): elif issubclass(cls, (FieldsClassType)):
field = self.as_field() field = self.as_field()
return field.contribute_to_class(cls, name) return field.contribute_to_class(cls, name)

View File

@ -5,7 +5,9 @@ import six
from graphql.core.type import GraphQLField, GraphQLInputObjectField from graphql.core.type import GraphQLField, GraphQLInputObjectField
from ...utils import to_camel_case from ...utils import to_camel_case
from ..types import BaseObjectType, InputObjectType from ..classtypes.base import FieldsClassType
from ..classtypes.inputobjecttype import InputObjectType
from ..classtypes.mutation import Mutation
from .argument import ArgumentsGroup, snake_case_args from .argument import ArgumentsGroup, snake_case_args
from .base import LazyType, MountType, OrderedType from .base import LazyType, MountType, OrderedType
from .definitions import NonNull from .definitions import NonNull
@ -32,7 +34,7 @@ class Field(OrderedType):
def contribute_to_class(self, cls, attname): def contribute_to_class(self, cls, attname):
assert issubclass( assert issubclass(
cls, BaseObjectType), 'Field {} cannot be mounted in {}'.format( cls, (FieldsClassType)), 'Field {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name: if not self.name:
self.name = to_camel_case(attname) self.name = to_camel_case(attname)
@ -69,7 +71,7 @@ class Field(OrderedType):
description = resolver.__doc__ description = resolver.__doc__
type = schema.T(self.get_type(schema)) type = schema.T(self.get_type(schema))
type_objecttype = schema.objecttype(type) type_objecttype = schema.objecttype(type)
if type_objecttype and type_objecttype._meta.is_mutation: if type_objecttype and issubclass(type_objecttype, Mutation):
assert len(arguments) == 0 assert len(arguments) == 0
arguments = type_objecttype.get_arguments() arguments = type_objecttype.get_arguments()
resolver = getattr(type_objecttype, 'mutate') resolver = getattr(type_objecttype, 'mutate')
@ -126,7 +128,7 @@ class InputField(OrderedType):
def contribute_to_class(self, cls, attname): def contribute_to_class(self, cls, attname):
assert issubclass( assert issubclass(
cls, InputObjectType), 'InputField {} cannot be mounted in {}'.format( cls, (InputObjectType)), 'InputField {} cannot be mounted in {}'.format(
self, cls) self, cls)
if not self.name: if not self.name:
self.name = to_camel_case(attname) self.name = to_camel_case(attname)

View File

@ -1,282 +1,3 @@
import copy from ..classtypes import InputObjectType, Interface, Mutation, ObjectType
import inspect
from collections import OrderedDict
from functools import partial
import six __all__ = ['ObjectType', 'Interface', 'Mutation', 'InputObjectType']
from graphql.core.type import (GraphQLInputObjectType, GraphQLInterfaceType,
GraphQLObjectType, GraphQLUnionType)
from graphene import signals
from ..exceptions import SkipField
from ..options import Options
from .argument import ArgumentsGroup
from .base import BaseType
from .definitions import List, NonNull
def is_objecttype(cls):
if not issubclass(cls, BaseObjectType):
return False
_meta = getattr(cls, '_meta', None)
return not(_meta and (_meta.abstract or _meta.is_interface))
class ObjectTypeMeta(type):
options_cls = Options
def is_interface(cls, parents):
return Interface in parents
def is_mutation(cls, parents):
return issubclass(cls, Mutation)
def __new__(cls, name, bases, attrs):
super_new = super(ObjectTypeMeta, cls).__new__
parents = [b for b in bases if isinstance(b, cls)]
if not parents:
# If this isn't a subclass of Model, don't do anything special.
return super_new(cls, name, bases, attrs)
module = attrs.pop('__module__', None)
doc = attrs.pop('__doc__', None)
new_class = super_new(cls, name, bases, {
'__module__': module,
'__doc__': doc
})
attr_meta = attrs.pop('Meta', None)
abstract = getattr(attr_meta, 'abstract', False)
if not attr_meta:
meta = getattr(new_class, 'Meta', None)
else:
meta = attr_meta
base_meta = getattr(new_class, '_meta', None)
new_class.add_to_class('_meta', new_class.options_cls(meta))
new_class._meta.is_interface = new_class.is_interface(parents)
new_class._meta.is_mutation = new_class.is_mutation(parents) or (base_meta and base_meta.is_mutation)
union_types = list(filter(is_objecttype, parents))
new_class._meta.is_union = len(union_types) > 1
new_class._meta.types = union_types
assert not (
new_class._meta.is_interface and new_class._meta.is_mutation)
assert not (
new_class._meta.is_interface and new_class._meta.is_union)
# Add all attributes to the class.
for obj_name, obj in attrs.items():
new_class.add_to_class(obj_name, obj)
if abstract:
new_class._prepare()
return new_class
if new_class._meta.is_mutation:
assert hasattr(
new_class, 'mutate'), "All mutations must implement mutate method"
new_class.add_extra_fields()
new_fields = new_class._meta.local_fields
assert not(new_class._meta.is_union and new_fields), 'An union cannot have extra fields'
field_names = {f.name: f for f in new_fields}
for base in parents:
if not hasattr(base, '_meta'):
# Things without _meta aren't functional models, so they're
# uninteresting parents.
continue
# if base._meta.schema != new_class._meta.schema:
# raise Exception('The parent schema is not the same')
parent_fields = base._meta.local_fields
# Check for clashes between locally declared fields and those
# on the base classes (we cannot handle shadowed fields at the
# moment).
for field in parent_fields:
if field.name in field_names and field.type.__class__ != field_names[
field.name].type.__class__:
raise Exception(
'Local field %r in class %r (%r) clashes '
'with field with similar name from '
'Interface %s (%r)' % (
field.name,
new_class.__name__,
field.__class__,
base.__name__,
field_names[field.name].__class__)
)
new_field = copy.copy(field)
new_class.add_to_class(field.attname, new_field)
new_class._meta.parents.append(base)
if base._meta.is_interface:
new_class._meta.interfaces.append(base)
# new_class._meta.parents.extend(base._meta.parents)
setattr(new_class, 'NonNull', NonNull(new_class))
setattr(new_class, 'List', List(new_class))
new_class._prepare()
return new_class
def add_extra_fields(cls):
pass
def _prepare(cls):
if hasattr(cls, '_prepare_class'):
cls._prepare_class()
signals.class_prepared.send(cls)
def add_to_class(cls, name, value):
# We should call the contribute_to_class method only if it's bound
if not inspect.isclass(value) and hasattr(
value, 'contribute_to_class'):
value.contribute_to_class(cls, name)
else:
setattr(cls, name, value)
class BaseObjectType(BaseType):
def __new__(cls, *args, **kwargs):
if cls._meta.is_interface:
raise Exception("An interface cannot be initialized")
elif cls._meta.is_union:
raise Exception("An union cannot be initialized")
elif cls._meta.abstract:
raise Exception("An abstract ObjectType cannot be initialized")
return super(BaseObjectType, cls).__new__(cls)
def __init__(self, *args, **kwargs):
signals.pre_init.send(self.__class__, args=args, kwargs=kwargs)
self._root = kwargs.pop('_root', None)
args_len = len(args)
fields = self._meta.fields
if args_len > len(fields):
# Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields")
fields_iter = iter(fields)
if not kwargs:
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
else:
for val, field in zip(args, fields_iter):
setattr(self, field.attname, val)
kwargs.pop(field.attname, None)
for field in fields_iter:
try:
val = kwargs.pop(field.attname)
setattr(self, field.attname, val)
except KeyError:
pass
if kwargs:
for prop in list(kwargs):
try:
if isinstance(getattr(self.__class__, prop), property):
setattr(self, prop, kwargs.pop(prop))
except AttributeError:
pass
if kwargs:
raise TypeError(
"'%s' is an invalid keyword argument for this function" %
list(kwargs)[0])
signals.post_init.send(self.__class__, instance=self)
@classmethod
def _resolve_type(cls, schema, instance, *args):
return schema.T(instance.__class__)
@classmethod
def internal_type(cls, schema):
if cls._meta.abstract:
raise Exception("Abstract ObjectTypes don't have a specific type.")
if cls._meta.is_interface:
return GraphQLInterfaceType(
cls._meta.type_name,
description=cls._meta.description,
resolve_type=partial(cls._resolve_type, schema),
fields=partial(cls.get_fields, schema)
)
elif cls._meta.is_union:
return GraphQLUnionType(
cls._meta.type_name,
types=cls._meta.types,
description=cls._meta.description,
)
return GraphQLObjectType(
cls._meta.type_name,
description=cls._meta.description,
interfaces=[schema.T(i) for i in cls._meta.interfaces],
fields=partial(cls.get_fields, schema),
is_type_of=getattr(cls, 'is_type_of', None)
)
@classmethod
def get_fields(cls, schema):
fields = []
for field in cls._meta.fields:
try:
fields.append((field.name, schema.T(field)))
except SkipField:
continue
return OrderedDict(fields)
@classmethod
def wrap(cls, instance, args, info):
return cls(_root=instance)
class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass
class ObjectType(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass
class Mutation(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
@classmethod
def _construct_arguments(cls, items):
return ArgumentsGroup(**items)
@classmethod
def _prepare_class(cls):
input_class = getattr(cls, 'Input', None)
if input_class:
items = dict(vars(input_class))
items.pop('__dict__', None)
items.pop('__doc__', None)
items.pop('__module__', None)
items.pop('__weakref__', None)
cls.add_to_class('arguments', cls._construct_arguments(items))
delattr(cls, 'Input')
@classmethod
def get_arguments(cls):
return cls.arguments
class InputObjectType(ObjectType):
@classmethod
def internal_type(cls, schema):
return GraphQLInputObjectType(
cls._meta.type_name,
description=cls._meta.description,
fields=partial(cls.get_fields, schema),
)

View File

@ -1,194 +0,0 @@
from graphql.core.execution.middlewares.utils import (resolver_has_tag,
tag_resolver)
from graphql.core.type import (GraphQLInterfaceType, GraphQLObjectType,
GraphQLUnionType)
from py.test import raises
from graphene.core.schema import Schema
from graphene.core.types import Int, Interface, ObjectType, String
class Character(Interface):
'''Character description'''
name = String()
class Meta:
type_name = 'core_Character'
class Human(Character):
'''Human description'''
friends = String()
class Meta:
type_name = 'core_Human'
@property
def readonly_prop(self):
return 'readonly'
@property
def write_prop(self):
return self._write_prop
@write_prop.setter
def write_prop(self, value):
self._write_prop = value
class Droid(Character):
'''Droid description'''
class CharacterType(Droid, Human):
'''Union Type'''
schema = Schema()
def test_interface():
object_type = schema.T(Character)
assert Character._meta.is_interface is True
assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.type_name == 'core_Character'
assert object_type.description == 'Character description'
assert list(object_type.get_fields().keys()) == ['name']
def test_interface_cannot_initialize():
with raises(Exception) as excinfo:
Character()
assert 'An interface cannot be initialized' == str(excinfo.value)
def test_union():
object_type = schema.T(CharacterType)
assert CharacterType._meta.is_union is True
assert isinstance(object_type, GraphQLUnionType)
assert object_type.description == 'Union Type'
def test_union_cannot_initialize():
with raises(Exception) as excinfo:
CharacterType()
assert 'An union cannot be initialized' == str(excinfo.value)
def test_interface_resolve_type():
resolve_type = Character._resolve_type(schema, Human(object()))
assert isinstance(resolve_type, GraphQLObjectType)
def test_object_type():
object_type = schema.T(Human)
assert Human._meta.is_interface is False
assert Human._meta.type_name == 'core_Human'
assert isinstance(object_type, GraphQLObjectType)
assert object_type.description == 'Human description'
assert list(object_type.get_fields().keys()) == ['name', 'friends']
assert object_type.get_interfaces() == [schema.T(Character)]
assert Human._meta.fields_map['name'].object_type == Human
def test_object_type_container():
h = Human(name='My name')
assert h.name == 'My name'
def test_object_type_set_properties():
h = Human(readonly_prop='custom', write_prop='custom')
assert h.readonly_prop == 'readonly'
assert h.write_prop == 'custom'
def test_object_type_container_invalid_kwarg():
with raises(TypeError):
Human(invalid='My name')
def test_object_type_container_too_many_args():
with raises(IndexError):
Human('Peter', 'No friends :(', None)
def test_field_clashes():
with raises(Exception) as excinfo:
class Droid(Character):
name = Int()
assert 'clashes' in str(excinfo.value)
def test_fields_inherited_should_be_different():
assert Character._meta.fields_map['name'] != Human._meta.fields_map['name']
def test_field_mantain_resolver_tags():
class Droid(Character):
name = String()
def resolve_name(self, *args):
return 'My Droid'
tag_resolver(resolve_name, 'test')
field = schema.T(Droid._meta.fields_map['name'])
assert resolver_has_tag(field.resolver, 'test')
def test_type_has_nonnull():
class Droid(Character):
name = String()
assert Droid.NonNull.of_type == Droid
def test_type_has_list():
class Droid(Character):
name = String()
assert Droid.List.of_type == Droid
def test_abstracttype():
class MyObject1(ObjectType):
class Meta:
abstract = True
name1 = String()
class MyObject2(ObjectType):
class Meta:
abstract = True
name2 = String()
class MyObject(MyObject1, MyObject2):
pass
object_type = schema.T(MyObject)
assert list(MyObject._meta.fields_map.keys()) == ['name1', 'name2']
assert MyObject._meta.fields_map['name1'].object_type == MyObject
assert MyObject._meta.fields_map['name2'].object_type == MyObject
assert isinstance(object_type, GraphQLObjectType)
def test_abstracttype_initialize():
class MyAbstractObjectType(ObjectType):
class Meta:
abstract = True
with raises(Exception) as excinfo:
MyAbstractObjectType()
assert 'An abstract ObjectType cannot be initialized' == str(excinfo.value)
def test_abstracttype_type():
class MyAbstractObjectType(ObjectType):
class Meta:
abstract = True
with raises(Exception) as excinfo:
schema.T(MyAbstractObjectType)
assert 'Abstract ObjectTypes don\'t have a specific type.' == str(excinfo.value)

View File

@ -90,11 +90,5 @@ class GlobalIDField(Field):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(GlobalIDField, self).__init__(NonNull(ID()), *args, **kwargs) super(GlobalIDField, self).__init__(NonNull(ID()), *args, **kwargs)
def contribute_to_class(self, cls, name):
from graphene.relay.utils import is_node, is_node_type
in_node = is_node(cls) or is_node_type(cls)
assert in_node, 'GlobalIDField could only be inside a Node, but got %r' % cls
super(GlobalIDField, self).contribute_to_class(cls, name)
def resolver(self, instance, args, info): def resolver(self, instance, args, info):
return instance.to_global_id() return instance.to_global_id()

View File

@ -3,11 +3,14 @@ import warnings
from collections import Iterable from collections import Iterable
from functools import wraps from functools import wraps
import six
from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import to_global_id from graphql_relay.node.node import to_global_id
from ..core.types import (Boolean, Field, InputObjectType, Interface, List, from ..core.classtypes import InputObjectType, Interface, Mutation, ObjectType
Mutation, ObjectType, String) from ..core.classtypes.interface import InterfaceMeta
from ..core.classtypes.mutation import MutationMeta
from ..core.types import Boolean, Field, List, String
from ..core.types.argument import ArgumentsGroup from ..core.types.argument import ArgumentsGroup
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..utils import memoize from ..utils import memoize
@ -83,13 +86,10 @@ class Connection(ObjectType):
return self._connection_data return self._connection_data
class BaseNode(object): class NodeMeta(InterfaceMeta):
@classmethod def construct_get_node(cls):
def _prepare_class(cls): get_node = getattr(cls, 'get_node', None)
from graphene.relay.utils import is_node
if is_node(cls):
get_node = getattr(cls, 'get_node')
assert get_node, 'get_node classmethod not found in %s Node' % cls assert get_node, 'get_node classmethod not found in %s Node' % cls
assert callable(get_node), 'get_node have to be callable' assert callable(get_node), 'get_node have to be callable'
args = 3 args = 3
@ -111,6 +111,20 @@ class BaseNode(object):
setattr(cls, 'get_node', wrapped_node) setattr(cls, 'get_node', wrapped_node)
def construct(cls, *args, **kwargs):
cls = super(NodeMeta, cls).construct(*args, **kwargs)
if not cls._meta.abstract:
cls.construct_get_node()
return cls
class Node(six.with_metaclass(NodeMeta, Interface)):
'''An object with an ID'''
id = GlobalIDField()
class Meta:
abstract = True
def to_global_id(self): def to_global_id(self):
type_name = self._meta.type_name type_name = self._meta.type_name
return to_global_id(type_name, self.id) return to_global_id(type_name, self.id)
@ -127,27 +141,32 @@ class BaseNode(object):
return cls.edge_type return cls.edge_type
class Node(BaseNode, Interface):
'''An object with an ID'''
id = GlobalIDField()
class MutationInputType(InputObjectType): class MutationInputType(InputObjectType):
client_mutation_id = String(required=True) client_mutation_id = String(required=True)
class ClientIDMutation(Mutation): class RelayMutationMeta(MutationMeta):
client_mutation_id = String(required=True)
@classmethod def construct(cls, *args, **kwargs):
def _construct_arguments(cls, items): cls = super(RelayMutationMeta, cls).construct(*args, **kwargs)
if not cls._meta.abstract:
assert hasattr( assert hasattr(
cls, 'mutate_and_get_payload'), 'You have to implement mutate_and_get_payload' cls, 'mutate_and_get_payload'), 'You have to implement mutate_and_get_payload'
return cls
def construct_arguments(cls, items):
new_input_type = type('{}Input'.format( new_input_type = type('{}Input'.format(
cls._meta.type_name), (MutationInputType, ), items) cls._meta.type_name), (MutationInputType, ), items)
cls.add_to_class('input_type', new_input_type) cls.add_to_class('input_type', new_input_type)
return ArgumentsGroup(input=NonNull(new_input_type)) return ArgumentsGroup(input=NonNull(new_input_type))
class ClientIDMutation(six.with_metaclass(RelayMutationMeta, Mutation)):
client_mutation_id = String(required=True)
class Meta:
abstract = True
@classmethod @classmethod
def mutate(cls, instance, args, info): def mutate(cls, instance, args, info):
input = args.get('input') input = args.get('input')

View File

@ -1,10 +1,11 @@
from .types import BaseNode from .types import Node
def is_node(object_type): def is_node(object_type):
return object_type and issubclass( return object_type and issubclass(
object_type, BaseNode) and not is_node_type(object_type) object_type, Node) and not object_type._meta.abstract
def is_node_type(object_type): def is_node_type(object_type):
return BaseNode in object_type.__bases__ return object_type and issubclass(
object_type, Node) and object_type._meta.abstract

View File

@ -2,6 +2,7 @@ try:
from blinker import Signal from blinker import Signal
except ImportError: except ImportError:
class Signal(object): class Signal(object):
def send(self, *args, **kwargs): def send(self, *args, **kwargs):
pass pass