Merge pull request #2 from syrusakbary/django

Django Models integration
This commit is contained in:
Syrus Akbary 2015-10-01 01:57:52 -07:00
commit b5f49b1014
41 changed files with 1315 additions and 217 deletions

View File

@ -4,6 +4,8 @@ python:
- 2.7 - 2.7
install: install:
- pip install pytest pytest-cov coveralls flake8 six blinker - pip install pytest pytest-cov coveralls flake8 six blinker
# - pip install -e .[django] # TODO: Commented until graphqllib is in pypi
- pip install Django>=1.8.0 pytest-django singledispatch>=3.4.0.3
- pip install git+https://github.com/dittos/graphqllib.git # Last version of graphqllib - pip install git+https://github.com/dittos/graphqllib.git # Last version of graphqllib
- pip install graphql-relay - pip install graphql-relay
- python setup.py develop - python setup.py develop

View File

@ -35,4 +35,4 @@ from graphene.decorators import (
resolve_only_args resolve_only_args
) )
import graphene.relay # import graphene.relay

View File

View File

@ -0,0 +1,4 @@
from graphene.contrib.django.types import (
DjangoObjectType,
DjangoNode
)

View File

@ -0,0 +1,69 @@
from singledispatch import singledispatch
from django.db import models
from graphene.core.fields import (
StringField,
IDField,
IntField,
BooleanField,
FloatField,
ListField
)
from graphene.contrib.django.fields import ConnectionOrListField, DjangoModelField
@singledispatch
def convert_django_field(field, cls):
raise Exception("Don't know how to convert the Django field %s (%s)" % (field, field.__class__))
@convert_django_field.register(models.DateField)
@convert_django_field.register(models.CharField)
@convert_django_field.register(models.TextField)
@convert_django_field.register(models.EmailField)
@convert_django_field.register(models.SlugField)
def _(field, cls):
return StringField(description=field.description)
@convert_django_field.register(models.AutoField)
def _(field, cls):
return IDField(description=field.description)
@convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.URLField)
@convert_django_field.register(models.UUIDField)
@convert_django_field.register(models.IntegerField)
def _(field, cls):
return IntField(description=field.description)
@convert_django_field.register(models.BooleanField)
def _(field, cls):
return BooleanField(description=field.description, null=False)
@convert_django_field.register(models.NullBooleanField)
def _(field, cls):
return BooleanField(description=field.description)
@convert_django_field.register(models.FloatField)
def _(field, cls):
return FloatField(description=field.description)
@convert_django_field.register(models.ManyToManyField)
@convert_django_field.register(models.ManyToOneRel)
def _(field, cls):
model_field = DjangoModelField(field.related_model)
return ConnectionOrListField(model_field)
@convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey)
def _(field, cls):
return DjangoModelField(field.related_model, description=field.description)

View File

@ -0,0 +1,62 @@
from graphene.core.fields import (
ListField
)
from graphene import relay
from graphene.core.fields import Field, LazyField
from graphene.utils import cached_property, memoize
from graphene.env import get_global_schema
from graphene.relay.types import BaseNode
from django.db.models.query import QuerySet
from django.db.models.manager import Manager
def get_type_for_model(schema, model):
schema = schema or get_global_schema()
types = schema.types.values()
for _type in types:
type_model = hasattr(_type,'_meta') and getattr(_type._meta, 'model', None)
if model == type_model:
return _type
class DjangoConnectionField(relay.ConnectionField):
def wrap_resolved(self, value, instance, args, info):
if isinstance(value, (QuerySet, Manager)):
cls = instance.__class__
value = [cls(s) for s in value.all()]
return value
class ConnectionOrListField(LazyField):
@memoize
def get_field(self, schema):
model_field = self.field_type
field_object_type = model_field.get_object_type(schema)
if field_object_type and issubclass(field_object_type, BaseNode):
field = DjangoConnectionField(model_field)
else:
field = ListField(model_field)
field.contribute_to_class(self.object_type, self.field_name)
return field
class DjangoModelField(Field):
def __init__(self, model, *args, **kwargs):
super(DjangoModelField, self).__init__(None, *args, **kwargs)
self.model = model
@memoize
def internal_type(self, schema):
_type = self.get_object_type(schema)
return _type and _type.internal_type(schema)
def get_object_type(self, schema):
_type = get_type_for_model(schema, self.model)
if not _type and self.object_type._meta.only_fields:
# We will only raise the exception if the related field is specified in only_fields
raise Exception("Field %s (%s) model not mapped in current schema" % (self, self.model._meta.object_name))
return _type

View File

@ -0,0 +1,24 @@
import inspect
from django.db import models
from graphene.core.options import Options
VALID_ATTRS = ('model', 'only_fields')
from graphene.relay.types import Node, BaseNode
class DjangoOptions(Options):
def __init__(self, *args, **kwargs):
self.model = None
super(DjangoOptions, self).__init__(*args, **kwargs)
self.valid_attrs += VALID_ATTRS
self.only_fields = None
def contribute_to_class(self, cls, name):
super(DjangoOptions, self).contribute_to_class(cls, name)
if cls.__name__ == 'DjangoNode':
return
if not self.model:
raise Exception('Django ObjectType %s must have a model in the Meta class attr' % cls)
elif not inspect.isclass(self.model) or not issubclass(self.model, models.Model):
raise Exception('Provided model in %s is not a Django model' % cls)

View File

@ -0,0 +1,45 @@
import six
from django.db import models
from graphene.core.types import ObjectTypeMeta, BaseObjectType
from graphene.contrib.django.options import DjangoOptions
from graphene.contrib.django.converter import convert_django_field
from graphene.relay.types import Node, BaseNode
def get_reverse_fields(model):
for name, attr in model.__dict__.items():
related = getattr(attr, 'related', None)
if isinstance(related, models.ManyToOneRel):
yield related
class DjangoObjectTypeMeta(ObjectTypeMeta):
options_cls = DjangoOptions
def is_interface(cls, parents):
return DjangoInterface in parents
def add_extra_fields(cls):
if not cls._meta.model:
return
only_fields = cls._meta.only_fields
reverse_fields = tuple(get_reverse_fields(cls._meta.model))
for field in cls._meta.model._meta.fields + reverse_fields:
if only_fields and field.name not in only_fields:
continue
converted_field = convert_django_field(field, cls)
cls.add_to_class(field.name, converted_field)
class DjangoObjectType(six.with_metaclass(DjangoObjectTypeMeta, BaseObjectType)):
pass
class DjangoInterface(six.with_metaclass(DjangoObjectTypeMeta, BaseObjectType)):
pass
class DjangoNode(BaseNode, DjangoInterface):
pass

View File

@ -8,8 +8,10 @@ from graphql.core.type import (
GraphQLBoolean, GraphQLBoolean,
GraphQLID, GraphQLID,
GraphQLArgument, GraphQLArgument,
GraphQLFloat,
) )
from graphene.utils import cached_property from graphene.utils import cached_property, memoize
from graphene.core.types import BaseObjectType
class Field(object): class Field(object):
def __init__(self, field_type, resolve=None, null=True, args=None, description='', **extra_args): def __init__(self, field_type, resolve=None, null=True, args=None, description='', **extra_args):
@ -25,7 +27,6 @@ class Field(object):
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
self.field_name = name self.field_name = name
self.object_type = cls self.object_type = cls
self.schema = cls._meta.schema
if isinstance(self.field_type, Field) and not self.field_type.object_type: if isinstance(self.field_type, Field) and not self.field_type.object_type:
self.field_type.contribute_to_class(cls, name) self.field_type.contribute_to_class(cls, name)
cls._meta.add_field(self) cls._meta.add_field(self)
@ -43,41 +44,39 @@ class Field(object):
resolve_fn = lambda root, args, info: root.resolve(self.field_name, args, info) resolve_fn = lambda root, args, info: root.resolve(self.field_name, args, info)
return resolve_fn(instance, args, info) return resolve_fn(instance, args, info)
def get_object_type(self): def get_object_type(self, schema):
from graphene.core.types import ObjectType
field_type = self.field_type field_type = self.field_type
_is_class = inspect.isclass(field_type) _is_class = inspect.isclass(field_type)
if _is_class and issubclass(field_type, ObjectType): if isinstance(field_type, Field):
return field_type.get_object_type(schema)
if _is_class and issubclass(field_type, BaseObjectType):
return field_type return field_type
elif isinstance(field_type, basestring): elif isinstance(field_type, basestring):
if field_type == 'self': if field_type == 'self':
return self.object_type return self.object_type
elif self.schema:
return self.schema.get_type(field_type)
@cached_property
def type(self):
field_type = self.field_type
if isinstance(field_type, Field):
field_type = self.field_type.type
else: else:
object_type = self.get_object_type() return schema.get_type(field_type)
if object_type:
field_type = object_type._meta.type
field_type = self.type_wrapper(field_type)
return field_type
def type_wrapper(self, field_type): def type_wrapper(self, field_type):
if not self.null: if not self.null:
field_type = GraphQLNonNull(field_type) field_type = GraphQLNonNull(field_type)
return field_type return field_type
@cached_property @memoize
def field(self): def internal_type(self, schema):
if not self.field_type: field_type = self.field_type
raise Exception('Must specify a field GraphQL type for the field %s'%self.field_name) if isinstance(field_type, Field):
field_type = self.field_type.internal_type(schema)
else:
object_type = self.get_object_type(schema)
if object_type:
field_type = object_type.internal_type(schema)
field_type = self.type_wrapper(field_type)
return field_type
@memoize
def internal_field(self, schema):
if not self.object_type: if not self.object_type:
raise Exception('Field could not be constructed in a non graphene.Type or graphene.Interface') raise Exception('Field could not be constructed in a non graphene.Type or graphene.Interface')
@ -94,8 +93,10 @@ class Field(object):
','.join(meta_attrs.keys()) ','.join(meta_attrs.keys())
)) ))
internal_type = self.internal_type(schema)
return GraphQLField( return GraphQLField(
self.type, internal_type,
description=self.description, description=self.description,
args=self.args, args=self.args,
resolver=self.resolver, resolver=self.resolver,
@ -119,7 +120,46 @@ class Field(object):
class NativeField(Field): class NativeField(Field):
def __init__(self, field=None): def __init__(self, field=None):
super(NativeField, self).__init__(None) super(NativeField, self).__init__(None)
self.field = field or getattr(self, 'field') self.field = field
def get_field(self, schema):
return self.field
@memoize
def internal_field(self, schema):
return self.get_field(schema)
@memoize
def internal_type(self, schema):
return self.internal_field(schema).type
class LazyField(Field):
@memoize
def inner_field(self, schema):
return self.get_field(schema)
def internal_type(self, schema):
return self.inner_field(schema).internal_type(schema)
def internal_field(self, schema):
return self.inner_field(schema).internal_field(schema)
class LazyNativeField(NativeField):
def __init__(self, *args, **kwargs):
super(LazyNativeField, self).__init__(None, *args, **kwargs)
def get_field(self, schema):
raise NotImplementedError("get_field function not implemented for %s LazyField" % self.__class__)
@memoize
def internal_field(self, schema):
return self.get_field(schema)
@memoize
def internal_type(self, schema):
return self.internal_field(schema).type
class TypeField(Field): class TypeField(Field):
@ -143,6 +183,10 @@ class IDField(TypeField):
field_type = GraphQLID field_type = GraphQLID
class FloatField(TypeField):
field_type = GraphQLFloat
class ListField(Field): class ListField(Field):
def type_wrapper(self, field_type): def type_wrapper(self, field_type):
return GraphQLList(field_type) return GraphQLList(field_type)

View File

@ -1,19 +1,18 @@
from graphene.env import get_global_schema
from graphene.utils import cached_property from graphene.utils import cached_property
DEFAULT_NAMES = ('description', 'name', 'interface', 'schema', DEFAULT_NAMES = ('description', 'name', 'interface',
'type_name', 'interfaces', 'proxy') 'type_name', 'interfaces', 'proxy')
class Options(object): class Options(object):
def __init__(self, meta=None, schema=None): def __init__(self, meta=None):
self.meta = meta self.meta = meta
self.local_fields = [] self.local_fields = []
self.interface = False self.interface = False
self.proxy = False self.proxy = False
self.schema = schema
self.interfaces = [] self.interfaces = []
self.parents = [] self.parents = []
self.valid_attrs = DEFAULT_NAMES
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
cls._meta = self cls._meta = self
@ -36,7 +35,7 @@ class Options(object):
# over it, so we loop over the *original* dictionary instead. # over it, so we loop over the *original* dictionary instead.
if name.startswith('_'): if name.startswith('_'):
del meta_attrs[name] del meta_attrs[name]
for attr_name in DEFAULT_NAMES: for attr_name in self.valid_attrs:
if attr_name in meta_attrs: if attr_name in meta_attrs:
setattr(self, attr_name, meta_attrs.pop(attr_name)) setattr(self, attr_name, meta_attrs.pop(attr_name))
self.original_attrs[attr_name] = getattr(self, attr_name) self.original_attrs[attr_name] = getattr(self, attr_name)
@ -44,9 +43,13 @@ class Options(object):
setattr(self, attr_name, getattr(self.meta, attr_name)) setattr(self, attr_name, getattr(self.meta, attr_name))
self.original_attrs[attr_name] = getattr(self, attr_name) self.original_attrs[attr_name] = getattr(self, attr_name)
del self.valid_attrs
# Any leftover attributes must be invalid. # Any leftover attributes must be invalid.
if meta_attrs != {}: if meta_attrs != {}:
raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs.keys())) raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs.keys()))
else:
self.proxy = False
if self.interfaces != [] and self.interface: if self.interfaces != [] and self.interface:
raise Exception("A interface cannot inherit from interfaces") raise Exception("A interface cannot inherit from interfaces")
@ -66,7 +69,3 @@ class Options(object):
@cached_property @cached_property
def fields_map(self): def fields_map(self):
return {f.field_name: f for f in self.fields} return {f.field_name: f for f in self.fields}
@cached_property
def type(self):
return self.parent.get_graphql_type()

View File

@ -1,3 +1,5 @@
from functools import wraps
from graphql.core import graphql from graphql.core import graphql
from graphql.core.type import ( from graphql.core.type import (
GraphQLSchema GraphQLSchema
@ -10,10 +12,10 @@ class Schema(object):
_query = None _query = None
def __init__(self, query=None, mutation=None, name='Schema'): def __init__(self, query=None, mutation=None, name='Schema'):
self._internal_types = {}
self.mutation = mutation self.mutation = mutation
self.query = query self.query = query
self.name = name self.name = name
self._types = {}
signals.init_schema.send(self) signals.init_schema.send(self)
def __repr__(self): def __repr__(self):
@ -25,27 +27,33 @@ class Schema(object):
@query.setter @query.setter
def query(self, query): def query(self, query):
if not query:
return
self._query = query self._query = query
self._query_type = query._meta.type self._query_type = query and query.internal_type(self)
self._schema = GraphQLSchema(query=self._query_type, mutation=self.mutation)
def register_type(self, type): @cached_property
type_name = type._meta.type_name def schema(self):
if type_name in self._types: if not self._query_type:
raise Exception('Type name %s already registered in %r' % (type_name, self)) raise Exception('You have to define a base query type')
self._types[type_name] = type return GraphQLSchema(query=self._query_type, mutation=self.mutation)
def associate_internal_type(self, internal_type, object_type):
self._internal_types[internal_type.name] = object_type
def get_type(self, type_name): def get_type(self, type_name):
if type_name not in self._types: # print 'get_type'
# _type = self.schema.get_type(type_name)
if type_name not in self._internal_types:
raise Exception('Type %s not found in %r' % (type_name, self)) raise Exception('Type %s not found in %r' % (type_name, self))
return self._types[type_name] return self._internal_types[type_name]
@property
def types(self):
return self._internal_types
def execute(self, request='', root=None, vars=None, operation_name=None): def execute(self, request='', root=None, vars=None, operation_name=None):
root = root or object() root = root or object()
return graphql( return graphql(
self._schema, self.schema,
request=request, request=request,
root=self.query(root), root=self.query(root),
vars=vars, vars=vars,
@ -55,9 +63,12 @@ class Schema(object):
def introspect(self): def introspect(self):
return self._schema.get_type_map() return self._schema.get_type_map()
def register_internal_type(fun):
@wraps(fun)
def wrapper(cls, schema):
internal_type = fun(cls, schema)
if isinstance(schema, Schema):
schema.associate_internal_type(internal_type, cls)
return internal_type
@signals.class_prepared.connect return wrapper
def object_type_created(object_type):
schema = object_type._meta.schema
if schema:
schema.register_type(object_type)

View File

@ -8,12 +8,18 @@ from graphql.core.type import (
from graphene import signals from graphene import signals
from graphene.core.options import Options from graphene.core.options import Options
from graphene.utils import memoize
from graphene.core.schema import register_internal_type
class ObjectTypeMeta(type): class ObjectTypeMeta(type):
options_cls = Options
def is_interface(cls, parents):
return Interface in parents
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
super_new = super(ObjectTypeMeta, cls).__new__ super_new = super(ObjectTypeMeta, cls).__new__
parents = [b for b in bases if isinstance(b, ObjectTypeMeta)] parents = [b for b in bases if isinstance(b, cls)]
if not parents: if not parents:
# If this isn't a subclass of Model, don't do anything special. # If this isn't a subclass of Model, don't do anything special.
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
@ -26,19 +32,20 @@ class ObjectTypeMeta(type):
}) })
attr_meta = attrs.pop('Meta', None) attr_meta = attrs.pop('Meta', None)
if not attr_meta: if not attr_meta:
meta = getattr(new_class, 'Meta', None) meta = None
# meta = getattr(new_class, 'Meta', None)
else: else:
meta = attr_meta meta = attr_meta
base_meta = getattr(new_class, '_meta', None) base_meta = getattr(new_class, '_meta', None)
schema = (base_meta and base_meta.schema) new_class.add_to_class('_meta', new_class.options_cls(meta))
new_class.add_to_class('_meta', Options(meta, schema)) new_class._meta.interface = new_class.is_interface(parents)
if base_meta and base_meta.proxy:
new_class._meta.interface = base_meta.interface
# Add all attributes to the class. # Add all attributes to the class.
for obj_name, obj in attrs.items(): for obj_name, obj in attrs.items():
new_class.add_to_class(obj_name, obj) new_class.add_to_class(obj_name, obj)
new_class.add_extra_fields()
new_fields = new_class._meta.local_fields new_fields = new_class._meta.local_fields
field_names = {f.field_name for f in new_fields} field_names = {f.field_name for f in new_fields}
@ -71,6 +78,9 @@ class ObjectTypeMeta(type):
new_class._prepare() new_class._prepare()
return new_class return new_class
def add_extra_fields(cls):
pass
def _prepare(cls): def _prepare(cls):
signals.class_prepared.send(cls) signals.class_prepared.send(cls)
@ -82,13 +92,13 @@ class ObjectTypeMeta(type):
setattr(cls, name, value) setattr(cls, name, value)
class ObjectType(six.with_metaclass(ObjectTypeMeta)): class BaseObjectType(object):
def __new__(cls, instance=None, *args, **kwargs): def __new__(cls, instance=None, *args, **kwargs):
if cls._meta.interface: if cls._meta.interface:
raise Exception("An interface cannot be initialized") raise Exception("An interface cannot be initialized")
if instance == None: if instance == None:
return None return None
return super(ObjectType, cls).__new__(cls, instance, *args, **kwargs) return super(BaseObjectType, cls).__new__(cls, instance, *args, **kwargs)
def __init__(self, instance=None): def __init__(self, instance=None):
signals.pre_init.send(self.__class__, instance=instance) signals.pre_init.send(self.__class__, instance=instance)
@ -117,45 +127,35 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)):
return True return True
@classmethod @classmethod
def resolve_type(cls, instance, *_): def resolve_type(cls, schema, instance, *_):
return instance._meta.type return instance.internal_type(schema)
@classmethod @classmethod
def get_graphql_type(cls): @memoize
fields = cls._meta.fields_map @register_internal_type
def internal_type(cls, schema):
fields_map = cls._meta.fields_map
fields = lambda: {
name: field.internal_field(schema)
for name, field in fields_map.items()
}
if cls._meta.interface: if cls._meta.interface:
return GraphQLInterfaceType( return GraphQLInterfaceType(
cls._meta.type_name, cls._meta.type_name,
description=cls._meta.description, description=cls._meta.description,
resolve_type=cls.resolve_type, resolve_type=lambda *args, **kwargs: cls.resolve_type(schema, *args, **kwargs),
fields=lambda: {name: field.field for name, field in fields.items()} fields=fields
) )
return GraphQLObjectType( return GraphQLObjectType(
cls._meta.type_name, cls._meta.type_name,
description=cls._meta.description, description=cls._meta.description,
interfaces=[i._meta.type for i in cls._meta.interfaces], interfaces=[i.internal_type(schema) for i in cls._meta.interfaces],
fields=lambda: {name: field.field for name, field in fields.items()} fields=fields
) )
class Interface(ObjectType): class ObjectType(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
class Meta: pass
interface = True
proxy = True
class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
@signals.init_schema.connect pass
def add_types_to_schema(schema):
own_schema = schema
class _Interface(Interface):
class Meta:
schema = own_schema
proxy = True
class _ObjectType(ObjectType):
class Meta:
schema = own_schema
proxy = True
setattr(own_schema, 'Interface', _Interface)
setattr(own_schema, 'ObjectType', _ObjectType)

View File

@ -1,20 +1,10 @@
from graphene.relay.nodes import (
create_node_definitions
)
from graphene.relay.fields import ( from graphene.relay.fields import (
ConnectionField, ConnectionField,
NodeField
) )
import graphene.relay.connections import graphene.relay.connections
from graphene.relay.relay import ( from graphene.relay.types import (
Relay Node
) )
from graphene.env import get_global_schema
schema = get_global_schema()
relay = schema.relay
Node, NodeField = relay.Node, relay.NodeField

View File

@ -1,35 +1,15 @@
import collections
from graphql_relay.node.node import ( from graphql_relay.node.node import (
globalIdField globalIdField
) )
from graphql_relay.connection.connection import (
connectionDefinitions
)
from graphene import signals from graphene import signals
from graphene.relay.fields import NodeIDField
from graphene.core.fields import NativeField from graphene.relay.types import BaseNode, Node
from graphene.relay.utils import get_relay
from graphene.relay.relay import Relay
@signals.class_prepared.connect @signals.class_prepared.connect
def object_type_created(object_type): def object_type_created(object_type):
relay = get_relay(object_type._meta.schema) if issubclass(object_type, BaseNode) and BaseNode not in object_type.__bases__:
if relay and issubclass(object_type, relay.Node):
type_name = object_type._meta.type_name type_name = object_type._meta.type_name
# def getId(*args, **kwargs): field = NodeIDField()
# print '**GET ID', args, kwargs
# return 2
field = NativeField(globalIdField(type_name))
object_type.add_to_class('id', field) object_type.add_to_class('id', field)
assert hasattr(object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name assert hasattr(object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name
connection = connectionDefinitions(type_name, object_type._meta.type).connectionType
object_type.add_to_class('connection', connection)
@signals.init_schema.connect
def schema_created(schema):
setattr(schema, 'relay', Relay(schema))

View File

@ -6,9 +6,13 @@ from graphql_relay.connection.arrayconnection import (
from graphql_relay.connection.connection import ( from graphql_relay.connection.connection import (
connectionArgs connectionArgs
) )
from graphene.core.fields import Field from graphql_relay.node.node import (
globalIdField
)
from graphene.core.fields import Field, LazyNativeField
from graphene.utils import cached_property from graphene.utils import cached_property
from graphene.relay.utils import get_relay from graphene.utils import memoize
class ConnectionField(Field): class ConnectionField(Field):
@ -16,15 +20,30 @@ class ConnectionField(Field):
super(ConnectionField, self).__init__(field_type, resolve=resolve, super(ConnectionField, self).__init__(field_type, resolve=resolve,
args=connectionArgs, description=description) args=connectionArgs, description=description)
def wrap_resolved(self, value, instance, args, info):
return value
def resolve(self, instance, args, info): def resolve(self, instance, args, info):
resolved = super(ConnectionField, self).resolve(instance, args, info) resolved = super(ConnectionField, self).resolve(instance, args, info)
if resolved: if resolved:
resolved = self.wrap_resolved(resolved, instance, args, info)
assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable' assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable'
return connectionFromArray(resolved, args) return connectionFromArray(resolved, args)
@cached_property @memoize
def type(self): def internal_type(self, schema):
object_type = self.get_object_type() from graphene.relay.types import BaseNode
relay = get_relay(object_type._meta.schema) object_type = self.get_object_type(schema)
assert issubclass(object_type, relay.Node), 'Only nodes have connections.' assert issubclass(object_type, BaseNode), 'Only nodes have connections.'
return object_type.connection return object_type.get_connection(schema)
class NodeField(LazyNativeField):
def get_field(self, schema):
from graphene.relay.types import BaseNode
return BaseNode.get_definitions(schema).nodeField
class NodeIDField(LazyNativeField):
def get_field(self, schema):
return globalIdField(self.object_type._meta.type_name)

View File

@ -1,42 +0,0 @@
from graphql_relay.node.node import (
nodeDefinitions,
fromGlobalId
)
from graphene.env import get_global_schema
from graphene.core.types import Interface
from graphene.core.fields import Field, NativeField
def getSchemaNode(schema=None):
def getNode(globalId, *args):
_schema = schema or get_global_schema()
resolvedGlobalId = fromGlobalId(globalId)
_type, _id = resolvedGlobalId.type, resolvedGlobalId.id
object_type = schema.get_type(_type)
return object_type.get_node(_id)
return getNode
def getNodeType(obj):
return obj._meta.type
def create_node_definitions(getNode=None, getNodeType=getNodeType, schema=None):
getNode = getNode or getSchemaNode(schema)
_nodeDefinitions = nodeDefinitions(getNode, getNodeType)
_Interface = getattr(schema,'Interface', Interface)
class Node(_Interface):
@classmethod
def get_graphql_type(cls):
if cls is Node:
# Return only nodeInterface when is the Node Inerface
return _nodeDefinitions.nodeInterface
return super(Node, cls).get_graphql_type()
class NodeField(NativeField):
field = _nodeDefinitions.nodeField
return Node, NodeField

View File

@ -1,14 +0,0 @@
from graphene.relay.nodes import (
create_node_definitions
)
from graphene.relay.fields import (
ConnectionField,
)
class Relay(object):
def __init__(self, schema):
self.schema = schema
self.Node, self.NodeField = create_node_definitions(schema=self.schema)
self.ConnectionField = ConnectionField

49
graphene/relay/types.py Normal file
View File

@ -0,0 +1,49 @@
from graphql_relay.node.node import (
nodeDefinitions,
fromGlobalId
)
from graphql_relay.connection.connection import (
connectionDefinitions
)
from graphene.env import get_global_schema
from graphene.core.types import Interface
from graphene.core.fields import LazyNativeField
from graphene.utils import memoize
def get_node_type(schema, obj):
return obj.internal_type(schema)
def get_node(schema, globalId, *args):
resolvedGlobalId = fromGlobalId(globalId)
_type, _id = resolvedGlobalId.type, resolvedGlobalId.id
object_type = schema.get_type(_type)
return object_type.get_node(_id)
class BaseNode(object):
@classmethod
@memoize
def get_definitions(cls, schema):
return nodeDefinitions(lambda *args: get_node(schema, *args), lambda *args: get_node_type(schema, *args))
@classmethod
@memoize
def get_connection(cls, schema):
_type = cls.internal_type(schema)
type_name = cls._meta.type_name
connection = connectionDefinitions(type_name, _type).connectionType
return connection
@classmethod
def internal_type(cls, schema):
if cls is Node or BaseNode in cls.__bases__:
# Return only nodeInterface when is the Node Inerface
return BaseNode.get_definitions(schema).nodeInterface
return super(BaseNode, cls).internal_type(schema)
class Node(BaseNode, Interface):
pass

View File

@ -1,3 +0,0 @@
def get_relay(schema):
return getattr(schema, 'relay', None)

View File

@ -1,3 +1,5 @@
from functools import wraps
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
@ -14,3 +16,17 @@ class cached_property(object):
return self return self
value = obj.__dict__[self.func.__name__] = self.func(obj) value = obj.__dict__[self.func.__name__] = self.func(obj)
return value return value
def memoize(fun):
"""A simple memoize decorator for functions supporting positional args."""
@wraps(fun)
def wrapper(*args, **kwargs):
key = (args, frozenset(sorted(kwargs.items())))
try:
return cache[key]
except KeyError:
ret = cache[key] = fun(*args, **kwargs)
return ret
cache = {}
return wrapper

View File

@ -56,6 +56,7 @@ setup(
extras_require={ extras_require={
'django': [ 'django': [
'Django>=1.8.0,<1.9', 'Django>=1.8.0,<1.9',
'pytest-django',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
], ],
}, },

0
tests/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,17 @@
from datetime import date
from .models import Reporter, Article
r = Reporter(first_name='John', last_name='Smith', email='john@example.com')
r.save()
r2 = Reporter(first_name='Paul', last_name='Jones', email='paul@example.com')
r2.save()
a = Article(id=None, headline="This is a test", pub_date=date(2005, 7, 27), reporter=r)
a.save()
new_article = r.articles.create(headline="John's second story", pub_date=date(2005, 7, 29))
new_article2 = Article(headline="Paul's story", pub_date=date(2006, 1, 17))
r.articles.add(new_article2)

View File

@ -0,0 +1,25 @@
from __future__ import absolute_import
from django.db import models
class Reporter(models.Model):
first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30)
email = models.EmailField()
def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name)
class Meta:
app_label = 'contrib_django'
class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField()
reporter = models.ForeignKey(Reporter, related_name='articles')
def __str__(self): # __unicode__ on Python 2
return self.headline
class Meta:
ordering = ('headline',)
app_label = 'contrib_django'

View File

@ -0,0 +1,170 @@
from py.test import raises
from collections import namedtuple
from pytest import raises
import graphene
from graphene import relay
from graphene.contrib.django import (
DjangoObjectType,
DjangoNode
)
from .models import Reporter, Article
def test_should_raise_if_no_model():
with raises(Exception) as excinfo:
class Character1(DjangoObjectType):
pass
assert 'model in the Meta' in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Meta:
model = 1
assert 'not a Django model' in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class ReporterTypeError(DjangoObjectType):
class Meta:
model = Reporter
only_fields = ('articles', )
schema = graphene.Schema(query=ReporterTypeError)
query = '''
query ReporterQuery {
articles
}
'''
result = schema.execute(query)
assert not result.errors
assert 'articles (Article) model not mapped in current schema' in str(excinfo.value)
def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType):
class Meta:
model = Reporter
assert ReporterType2._meta.fields_map.keys() == ['articles', 'first_name', 'last_name', 'id', 'email']
def test_should_map_fields():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
class Query2(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs):
return ReporterType(Reporter(first_name='ABA', last_name='X'))
query = '''
query ReporterQuery {
reporter {
first_name,
last_name,
email
}
}
'''
expected = {
'reporter': {
'first_name': 'ABA',
'last_name': 'X',
'email': ''
}
}
Schema = graphene.Schema(query=Query2)
result = Schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType):
class Meta:
model = Reporter
only_fields = ('id', 'email')
assert Reporter2._meta.fields_map.keys() == ['id', 'email']
def test_should_node():
class ReporterNodeType(DjangoNode):
class Meta:
model = Reporter
@classmethod
def get_node(cls, id):
return ReporterNodeType(Reporter(id=2, first_name='Cookie Monster'))
def resolve_articles(self, *args, **kwargs):
return [ArticleNodeType(Article(headline='Hi!'))]
class ArticleNodeType(DjangoNode):
class Meta:
model = Article
@classmethod
def get_node(cls, id):
return ArticleNodeType(Article(id=1, headline='Article node'))
class Query1(graphene.ObjectType):
node = relay.NodeField()
reporter = graphene.Field(ReporterNodeType)
article = graphene.Field(ArticleNodeType)
def resolve_reporter(self, *args, **kwargs):
return ReporterNodeType(Reporter(id=1, first_name='ABA', last_name='X'))
query = '''
query ReporterQuery {
reporter {
id,
first_name,
articles {
edges {
node {
headline
}
}
}
last_name,
email
}
my_article: node(id:"QXJ0aWNsZU5vZGVUeXBlOjE=") {
id
... on ReporterNodeType {
first_name
}
... on ArticleNodeType {
headline
}
}
}
'''
expected = {
'reporter': {
'id': 'UmVwb3J0ZXJOb2RlVHlwZTox',
'first_name': 'ABA',
'last_name': 'X',
'email': '',
'articles': {
'edges': [{
'node': {
'headline': 'Hi!'
}
}]
},
},
'my_article': {
'id': 'QXJ0aWNsZU5vZGVUeXBlOjE=',
'headline': 'Article node'
}
}
Schema = graphene.Schema(query=Query1)
result = Schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,65 @@
from py.test import raises
from collections import namedtuple
from pytest import raises
from graphene.core.fields import (
Field,
StringField,
)
from graphql.core.type import (
GraphQLObjectType,
GraphQLInterfaceType
)
from graphene import Schema
from graphene.contrib.django.types import (
DjangoNode,
DjangoInterface
)
from .models import Reporter, Article
class Character(DjangoInterface):
'''Character description'''
class Meta:
model = Reporter
class Human(DjangoNode):
'''Human description'''
def get_node(self, id):
pass
class Meta:
model = Article
schema = Schema()
def test_django_interface():
assert DjangoNode._meta.interface == True
def test_pseudo_interface():
object_type = Character.internal_type(schema)
assert Character._meta.interface == True
assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.model == Reporter
assert object_type.get_fields().keys() == ['articles', 'first_name', 'last_name', 'id', 'email']
def test_interface_resolve_type():
resolve_type = Character.resolve_type(schema, Human(object()))
assert isinstance(resolve_type, GraphQLObjectType)
def test_object_type():
object_type = Human.internal_type(schema)
assert Human._meta.interface == False
assert isinstance(object_type, GraphQLObjectType)
assert object_type.get_fields() == {
'headline': Human._meta.fields_map['headline'].internal_field(schema),
'id': Human._meta.fields_map['id'].internal_field(schema),
'reporter': Human._meta.fields_map['reporter'].internal_field(schema),
'pub_date': Human._meta.fields_map['pub_date'].internal_field(schema),
}
assert object_type.get_interfaces() == [DjangoNode.internal_type(schema)]

View File

@ -28,34 +28,65 @@ ot = ObjectType()
ObjectType._meta.contribute_to_class(ObjectType, '_meta') ObjectType._meta.contribute_to_class(ObjectType, '_meta')
class Schema(object):
pass
schema = Schema()
def test_field_no_contributed_raises_error(): def test_field_no_contributed_raises_error():
f = Field(GraphQLString) f = Field(GraphQLString)
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
f.field f.internal_field(schema)
def test_field_type(): def test_field_type():
f = Field(GraphQLString) f = Field(GraphQLString)
f.contribute_to_class(ot, 'field_name') f.contribute_to_class(ot, 'field_name')
assert isinstance(f.field, GraphQLField) assert isinstance(f.internal_field(schema), GraphQLField)
assert f.type == GraphQLString assert f.internal_type(schema) == GraphQLString
def test_stringfield_type(): def test_stringfield_type():
f = StringField() f = StringField()
f.contribute_to_class(ot, 'field_name') f.contribute_to_class(ot, 'field_name')
assert f.type == GraphQLString assert f.internal_type(schema) == GraphQLString
def test_stringfield_type_null(): def test_stringfield_type_null():
f = StringField(null=False) f = StringField(null=False)
f.contribute_to_class(ot, 'field_name') f.contribute_to_class(ot, 'field_name')
assert isinstance(f.field, GraphQLField) assert isinstance(f.internal_field(schema), GraphQLField)
assert isinstance(f.type, GraphQLNonNull) assert isinstance(f.internal_type(schema), GraphQLNonNull)
def test_field_resolve(): def test_field_resolve():
f = StringField(null=False) f = StringField(null=False, resolve=lambda *args:'RESOLVED')
f.contribute_to_class(ot, 'field_name') f.contribute_to_class(ot, 'field_name')
field_type = f.field field_type = f.internal_field(schema)
field_type.resolver(ot,2,3) assert 'RESOLVED' == field_type.resolver(ot,2,3)
def test_field_resolve_type_custom():
class MyCustomType(object):
pass
class Schema(object):
def get_type(self, name):
if name == 'MyCustomType':
return MyCustomType
s = Schema()
f = Field('MyCustomType')
f.contribute_to_class(ot, 'field_name')
field_type = f.get_object_type(s)
assert field_type == MyCustomType
def test_field_resolve_type_custom():
s = Schema()
f = Field('self')
f.contribute_to_class(ot, 'field_name')
field_type = f.get_object_type(s)
assert field_type == ot

68
tests/core/test_query.py Normal file
View File

@ -0,0 +1,68 @@
from py.test import raises
from collections import namedtuple
from pytest import raises
from graphql.core import graphql
from graphene.core.fields import (
Field,
StringField,
ListField,
)
from graphql.core.type import (
GraphQLObjectType,
GraphQLSchema,
GraphQLInterfaceType
)
from graphene.core.types import (
Interface,
ObjectType
)
class Character(Interface):
name = StringField()
class Pet(ObjectType):
type = StringField(resolve=lambda *_:'Dog')
class Human(Character):
friends = ListField(Character)
pet = Field(Pet)
def resolve_name(self, *args):
return 'Peter'
def resolve_friend(self, *args):
return Human(object())
def resolve_pet(self, *args):
return Pet(object())
# def resolve_friends(self, *args, **kwargs):
# return 'HEY YOU!'
schema = object()
Human_type = Human.internal_type(schema)
def test_query():
schema = GraphQLSchema(query=Human_type)
query = '''
{
name
pet {
type
}
}
'''
expected = {
'name': 'Peter',
'pet': {
'type':'Dog'
}
}
result = graphql(schema, query, root=Human(object()))
assert not result.errors
assert result.data == expected

103
tests/core/test_schema.py Normal file
View File

@ -0,0 +1,103 @@
from py.test import raises
from collections import namedtuple
from pytest import raises
from graphql.core import graphql
from graphene.core.fields import (
Field,
StringField,
ListField,
)
from graphql.core.type import (
GraphQLObjectType,
GraphQLSchema,
GraphQLInterfaceType
)
from graphene import (
Interface,
ObjectType,
Schema
)
schema = Schema(name='My own schema')
class Character(Interface):
name = StringField()
class Pet(ObjectType):
type = StringField(resolve=lambda *_:'Dog')
class Human(Character):
friends = ListField(Character)
pet = Field(Pet)
def resolve_name(self, *args):
return 'Peter'
def resolve_friend(self, *args):
return Human(object())
def resolve_pet(self, *args):
return Pet(object())
schema.query = Human
def test_get_registered_type():
assert schema.get_type('Character') == Character
def test_get_unregistered_type():
with raises(Exception) as excinfo:
schema.get_type('NON_EXISTENT_MODEL')
assert 'not found' in str(excinfo.value)
def test_schema_query():
assert schema.query == Human
def test_query_schema_graphql():
a = object()
query = '''
{
name
pet {
type
}
}
'''
expected = {
'name': 'Peter',
'pet': {
'type':'Dog'
}
}
result = graphql(schema.schema, query, root=Human(object()))
assert not result.errors
assert result.data == expected
def test_query_schema_execute():
a = object()
query = '''
{
name
pet {
type
}
}
'''
expected = {
'name': 'Peter',
'pet': {
'type':'Dog'
}
}
result = schema.execute(query, root=object())
assert not result.errors
assert result.data == expected
def test_schema_get_type_map():
assert schema.schema.get_type_map().keys() == ['__Field', 'String', 'Pet', 'Character', '__InputValue', '__Directive', '__TypeKind', '__Schema', '__Type', 'Human', '__EnumValue', 'Boolean']

View File

@ -15,31 +15,43 @@ from graphene.core.types import (
ObjectType ObjectType
) )
class Character(Interface): class Character(Interface):
'''Character description''' '''Character description'''
name = StringField() name = StringField()
class Meta: class Meta:
type_name = 'core.Character' type_name = 'core.Character'
class Human(Character): class Human(Character):
'''Human description''' '''Human description'''
friends = StringField() friends = StringField()
class Meta: class Meta:
type_name = 'core.Human' type_name = 'core.Human'
schema = object()
def test_interface(): def test_interface():
object_type = Character._meta.type object_type = Character.internal_type(schema)
assert Character._meta.interface == True assert Character._meta.interface == True
assert Character._meta.type_name == 'core.Character'
assert isinstance(object_type, GraphQLInterfaceType) assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.type_name == 'core.Character'
assert object_type.description == 'Character description' assert object_type.description == 'Character description'
assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].field} assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].internal_field(schema)}
def test_interface_resolve_type():
resolve_type = Character.resolve_type(schema, Human(object()))
assert isinstance(resolve_type, GraphQLObjectType)
def test_object_type(): def test_object_type():
object_type = Human._meta.type object_type = Human.internal_type(schema)
assert Human._meta.interface == False assert Human._meta.interface == False
assert Human._meta.type_name == 'core.Human' assert Human._meta.type_name == 'core.Human'
assert isinstance(object_type, GraphQLObjectType) assert isinstance(object_type, GraphQLObjectType)
assert object_type.description == 'Human description' assert object_type.description == 'Human description'
assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].field, 'friends': Human._meta.fields_map['friends'].field} assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].internal_field(schema), 'friends': Human._meta.fields_map['friends'].internal_field(schema)}
assert object_type.get_interfaces() == [Character._meta.type] assert object_type.get_interfaces() == [Character.internal_type(schema)]

14
tests/django_settings.py Normal file
View File

@ -0,0 +1,14 @@
SECRET_KEY = 1
INSTALLED_APPS = [
'graphene.contrib.django',
'tests.starwars_django',
'tests.contrib_django',
]
DATABASES={
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': 'tests/django.sqlite',
}
}

View File

@ -4,7 +4,6 @@ import graphene
from graphene import relay from graphene import relay
schema = graphene.Schema() schema = graphene.Schema()
relay = schema.relay
class OtherNode(relay.Node): class OtherNode(relay.Node):
name = graphene.StringField() name = graphene.StringField()
@ -22,8 +21,12 @@ def test_field_no_contributed_raises_error():
assert 'get_node' in str(excinfo.value) assert 'get_node' in str(excinfo.value)
def test_node_should_have_connection(): def test_node_should_have_same_connection_always():
assert OtherNode.connection s = object()
connection1 = OtherNode.get_connection(s)
connection2 = OtherNode.get_connection(s)
assert connection1 == connection2
def test_node_should_have_id_field(): def test_node_should_have_id_field():

View File

View File

@ -0,0 +1,101 @@
from collections import namedtuple
from .models import Ship, Faction
def initialize():
rebels = Faction(
id='1',
name='Alliance to Restore the Republic',
)
rebels.save()
empire = Faction(
id='2',
name='Galactic Empire',
)
empire.save()
xwing = Ship(
id='1',
name='X-Wing',
faction=rebels,
)
xwing.save()
ywing = Ship(
id='2',
name='Y-Wing',
faction=rebels,
)
ywing.save()
awing = Ship(
id='3',
name='A-Wing',
faction=rebels,
)
awing.save()
# Yeah, technically it's Corellian. But it flew in the service of the rebels,
# so for the purposes of this demo it's a rebel ship.
falcon = Ship(
id='4',
name='Millenium Falcon',
faction=rebels,
)
falcon.save()
homeOne = Ship(
id='5',
name='Home One',
faction=rebels,
)
homeOne.save()
tieFighter = Ship(
id='6',
name='TIE Fighter',
faction=empire,
)
tieFighter.save()
tieInterceptor = Ship(
id='7',
name='TIE Interceptor',
faction=empire,
)
tieInterceptor.save()
executor = Ship(
id='8',
name='Executor',
faction=empire,
)
executor.save()
def createShip(shipName, factionId):
nextShip = len(data['Ship'].keys())+1
newShip = Ship(
id=str(nextShip),
name=shipName
)
newShip.save()
return newShip
def getShip(_id):
return Ship.objects.get(id=_id)
def getShips():
return Ship.objects.all()
def getFaction(_id):
return Faction.objects.get(id=_id)
def getRebels():
return getFaction(1)
def getEmpire():
return getFaction(2)

View File

@ -0,0 +1,17 @@
from __future__ import absolute_import
from django.db import models
class Faction(models.Model):
name = models.CharField(max_length=50)
def __str__(self):
return self.name
class Ship(models.Model):
name = models.CharField(max_length=50)
faction = models.ForeignKey(Faction, related_name='ships')
def __str__(self):
return self.name

View File

@ -0,0 +1,54 @@
import graphene
from graphene import resolve_only_args, relay
from graphene.contrib.django import (
DjangoObjectType,
DjangoNode
)
from .models import Ship as ShipModel, Faction as FactionModel
from .data import (
getFaction,
getShip,
getShips,
getRebels,
getEmpire,
)
schema = graphene.Schema(name='Starwars Django Relay Schema')
class Ship(DjangoNode):
class Meta:
model = ShipModel
@classmethod
def get_node(cls, id):
return Ship(getShip(id))
class Faction(DjangoNode):
class Meta:
model = FactionModel
@classmethod
def get_node(cls, id):
return Faction(getFaction(id))
class Query(graphene.ObjectType):
rebels = graphene.Field(Faction)
empire = graphene.Field(Faction)
node = relay.NodeField()
ships = relay.ConnectionField(Ship, description='All the ships.')
@resolve_only_args
def resolve_ships(self):
return [Ship(s) for s in getShips()]
@resolve_only_args
def resolve_rebels(self):
return Faction(getRebels())
@resolve_only_args
def resolve_empire(self):
return Faction(getEmpire())
schema.query = Query

View File

@ -0,0 +1,42 @@
import pytest
from graphql.core import graphql
from .models import *
from .schema import schema
from .data import initialize
pytestmark = pytest.mark.django_db
def test_correct_fetch_first_ship_rebels():
initialize()
query = '''
query RebelsShipsQuery {
rebels {
name,
ships(first: 1) {
edges {
node {
name
}
}
}
}
}
'''
expected = {
'rebels': {
'name': 'Alliance to Restore the Republic',
'ships': {
'edges': [
{
'node': {
'name': 'X-Wing'
}
}
]
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,114 @@
import pytest
from pytest import raises
from graphql.core import graphql
from .data import initialize
from .schema import schema
pytestmark = pytest.mark.django_db
def test_correctly_fetches_id_name_rebels():
initialize()
query = '''
query RebelsQuery {
rebels {
id
name
}
}
'''
expected = {
'rebels': {
'id': 'RmFjdGlvbjox',
'name': 'Alliance to Restore the Republic'
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_correctly_refetches_rebels():
initialize()
query = '''
query RebelsRefetchQuery {
node(id: "RmFjdGlvbjox") {
id
... on Faction {
name
}
}
}
'''
expected = {
'node': {
'id': 'RmFjdGlvbjox',
'name': 'Alliance to Restore the Republic'
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_correctly_fetches_id_name_empire():
initialize()
query = '''
query EmpireQuery {
empire {
id
name
}
}
'''
expected = {
'empire': {
'id': 'RmFjdGlvbjoy',
'name': 'Galactic Empire'
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_correctly_refetches_empire():
initialize()
query = '''
query EmpireRefetchQuery {
node(id: "RmFjdGlvbjoy") {
id
... on Faction {
name
}
}
}
'''
expected = {
'node': {
'id': 'RmFjdGlvbjoy',
'name': 'Galactic Empire'
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_correctly_refetches_xwing():
initialize()
query = '''
query XWingRefetchQuery {
node(id: "U2hpcDox") {
id
... on Ship {
name
}
}
}
'''
expected = {
'node': {
'id': 'U2hpcDox',
'name': 'X-Wing'
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -8,6 +8,8 @@ from .data import (
getEmpire, getEmpire,
) )
schema = graphene.Schema(name='Starwars Relay Schema')
class Ship(relay.Node): class Ship(relay.Node):
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
name = graphene.StringField(description='The name of the ship.') name = graphene.StringField(description='The name of the ship.')
@ -45,4 +47,4 @@ class Query(graphene.ObjectType):
return Faction(getEmpire()) return Faction(getEmpire())
schema = graphene.Schema(query=Query, name='Starwars Relay Schema') schema.query = Query

View File

@ -5,6 +5,7 @@ envlist = py27
deps= deps=
pytest>=2.7.2 pytest>=2.7.2
django>=1.8.0,<1.9 django>=1.8.0,<1.9
pytest-django
flake8 flake8
six six
blinker blinker
@ -12,3 +13,6 @@ deps=
commands= commands=
py.test py.test
flake8 flake8
[pytest]
DJANGO_SETTINGS_MODULE = tests.django_settings