Improved syntax using pep8 style guide

This commit is contained in:
Syrus Akbary 2015-10-02 22:17:51 -07:00
parent 9a3f11b802
commit 176696c1ac
41 changed files with 271 additions and 156 deletions

View File

@ -11,9 +11,11 @@ from graphene.core.fields import (
) )
from graphene.contrib.django.fields import ConnectionOrListField, DjangoModelField from graphene.contrib.django.fields import ConnectionOrListField, DjangoModelField
@singledispatch @singledispatch
def convert_django_field(field, cls): def convert_django_field(field, cls):
raise Exception("Don't know how to convert the Django field %s (%s)" % (field, field.__class__)) raise Exception(
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__))
@convert_django_field.register(models.DateField) @convert_django_field.register(models.DateField)

View File

@ -16,12 +16,14 @@ def get_type_for_model(schema, model):
schema = schema schema = schema
types = schema.types.values() types = schema.types.values()
for _type in types: for _type in types:
type_model = hasattr(_type,'_meta') and getattr(_type._meta, 'model', None) type_model = hasattr(_type, '_meta') and getattr(
_type._meta, 'model', None)
if model == type_model: if model == type_model:
return _type return _type
class DjangoConnectionField(relay.ConnectionField): class DjangoConnectionField(relay.ConnectionField):
def wrap_resolved(self, value, instance, args, info): def wrap_resolved(self, value, instance, args, info):
if isinstance(value, (QuerySet, Manager)): if isinstance(value, (QuerySet, Manager)):
cls = instance.__class__ cls = instance.__class__
@ -30,6 +32,7 @@ class DjangoConnectionField(relay.ConnectionField):
class ConnectionOrListField(LazyField): class ConnectionOrListField(LazyField):
@memoize @memoize
def get_field(self, schema): def get_field(self, schema):
model_field = self.field_type model_field = self.field_type
@ -43,6 +46,7 @@ class ConnectionOrListField(LazyField):
class DjangoModelField(Field): class DjangoModelField(Field):
def __init__(self, model, *args, **kwargs): def __init__(self, model, *args, **kwargs):
super(DjangoModelField, self).__init__(None, *args, **kwargs) super(DjangoModelField, self).__init__(None, *args, **kwargs)
self.model = model self.model = model
@ -55,7 +59,9 @@ class DjangoModelField(Field):
def get_object_type(self, schema): def get_object_type(self, schema):
_type = get_type_for_model(schema, self.model) _type = get_type_for_model(schema, self.model)
if not _type and self.object_type._meta.only_fields: if not _type and self.object_type._meta.only_fields:
# We will only raise the exception if the related field is specified in only_fields # We will only raise the exception if the related field is
raise Exception("Field %s (%s) model not mapped in current schema" % (self, self.model._meta.object_name)) # specified in only_fields
raise Exception("Field %s (%s) model not mapped in current schema" % (
self, self.model._meta.object_name))
return _type return _type

View File

@ -14,6 +14,7 @@ def is_base(cls):
class DjangoOptions(Options): class DjangoOptions(Options):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.model = None self.model = None
super(DjangoOptions, self).__init__(*args, **kwargs) super(DjangoOptions, self).__init__(*args, **kwargs)
@ -25,6 +26,7 @@ class DjangoOptions(Options):
if not is_node(cls) and not is_base(cls): if not is_node(cls) and not is_base(cls):
return return
if not self.model: if not self.model:
raise Exception('Django ObjectType %s must have a model in the Meta class attr' % cls) raise Exception(
'Django ObjectType %s must have a model in the Meta class attr' % cls)
elif not inspect.isclass(self.model) or not issubclass(self.model, models.Model): elif not inspect.isclass(self.model) or not issubclass(self.model, models.Model):
raise Exception('Provided model in %s is not a Django model' % cls) raise Exception('Provided model in %s is not a Django model' % cls)

View File

@ -13,7 +13,9 @@ from graphql.core.type import (
from graphene.utils import cached_property, memoize from graphene.utils import cached_property, memoize
from graphene.core.types import BaseObjectType from graphene.core.types import BaseObjectType
class Field(object): class Field(object):
def __init__(self, field_type, resolve=None, null=True, args=None, description='', **extra_args): def __init__(self, field_type, resolve=None, null=True, args=None, description='', **extra_args):
self.field_type = field_type self.field_type = field_type
self.resolve_fn = resolve self.resolve_fn = resolve
@ -41,7 +43,8 @@ class Field(object):
if self.resolve_fn: if self.resolve_fn:
resolve_fn = self.resolve_fn resolve_fn = self.resolve_fn
else: else:
resolve_fn = lambda root, args, info: root.resolve(self.field_name, args, info) resolve_fn = lambda root, args, info: root.resolve(
self.field_name, args, info)
return resolve_fn(instance, args, info) return resolve_fn(instance, args, info)
def get_object_type(self, schema): def get_object_type(self, schema):
@ -78,7 +81,8 @@ class Field(object):
@memoize @memoize
def internal_field(self, schema): def internal_field(self, schema):
if not self.object_type: if not self.object_type:
raise Exception('Field could not be constructed in a non graphene.Type or graphene.Interface') raise Exception(
'Field could not be constructed in a non graphene.Type or graphene.Interface')
extra_args = self.extra_args.copy() extra_args = self.extra_args.copy()
for arg_name, arg_value in extra_args.items(): for arg_name, arg_value in extra_args.items():
@ -118,6 +122,7 @@ class Field(object):
class NativeField(Field): class NativeField(Field):
def __init__(self, field=None): def __init__(self, field=None):
super(NativeField, self).__init__(None) super(NativeField, self).__init__(None)
self.field = field self.field = field
@ -135,6 +140,7 @@ class NativeField(Field):
class LazyField(Field): class LazyField(Field):
@memoize @memoize
def inner_field(self, schema): def inner_field(self, schema):
return self.get_field(schema) return self.get_field(schema)
@ -147,11 +153,13 @@ class LazyField(Field):
class LazyNativeField(NativeField): class LazyNativeField(NativeField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(LazyNativeField, self).__init__(None, *args, **kwargs) super(LazyNativeField, self).__init__(None, *args, **kwargs)
def get_field(self, schema): def get_field(self, schema):
raise NotImplementedError("get_field function not implemented for %s LazyField" % self.__class__) raise NotImplementedError(
"get_field function not implemented for %s LazyField" % self.__class__)
@memoize @memoize
def internal_field(self, schema): def internal_field(self, schema):
@ -163,6 +171,7 @@ class LazyNativeField(NativeField):
class TypeField(Field): class TypeField(Field):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TypeField, self).__init__(self.field_type, *args, **kwargs) super(TypeField, self).__init__(self.field_type, *args, **kwargs)
@ -188,10 +197,12 @@ class FloatField(TypeField):
class ListField(Field): class ListField(Field):
def type_wrapper(self, field_type): def type_wrapper(self, field_type):
return GraphQLList(field_type) return GraphQLList(field_type)
class NonNullField(Field): class NonNullField(Field):
def type_wrapper(self, field_type): def type_wrapper(self, field_type):
return GraphQLNonNull(field_type) return GraphQLNonNull(field_type)

View File

@ -5,6 +5,7 @@ DEFAULT_NAMES = ('description', 'name', 'interface',
class Options(object): class Options(object):
def __init__(self, meta=None): def __init__(self, meta=None):
self.meta = meta self.meta = meta
self.local_fields = [] self.local_fields = []
@ -47,7 +48,8 @@ class Options(object):
# Any leftover attributes must be invalid. # Any leftover attributes must be invalid.
if meta_attrs != {}: if meta_attrs != {}:
raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs.keys())) raise TypeError(
"'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs.keys()))
else: else:
self.proxy = False self.proxy = False

View File

@ -63,6 +63,7 @@ class Schema(object):
def introspect(self): def introspect(self):
return self._schema.get_type_map() return self._schema.get_type_map()
def register_internal_type(fun): def register_internal_type(fun):
@wraps(fun) @wraps(fun)
def wrapper(cls, schema): def wrapper(cls, schema):

View File

@ -11,6 +11,7 @@ from graphene.core.options import Options
from graphene.utils import memoize from graphene.utils import memoize
from graphene.core.schema import register_internal_type from graphene.core.schema import register_internal_type
class ObjectTypeMeta(type): class ObjectTypeMeta(type):
options_cls = Options options_cls = Options
@ -68,7 +69,8 @@ class ObjectTypeMeta(type):
raise Exception( raise Exception(
'Local field %r in class %r clashes ' 'Local field %r in class %r clashes '
'with field of similar name from ' 'with field of similar name from '
'base class %r' % (field.field_name, name, base.__name__) 'base class %r' % (
field.field_name, name, base.__name__)
) )
new_class._meta.parents.append(base) new_class._meta.parents.append(base)
if base._meta.interface: if base._meta.interface:
@ -93,6 +95,7 @@ class ObjectTypeMeta(type):
class BaseObjectType(object): class BaseObjectType(object):
def __new__(cls, instance=None, *args, **kwargs): def __new__(cls, instance=None, *args, **kwargs):
if cls._meta.interface: if cls._meta.interface:
raise Exception("An interface cannot be initialized") raise Exception("An interface cannot be initialized")
@ -143,7 +146,8 @@ class BaseObjectType(object):
return GraphQLInterfaceType( return GraphQLInterfaceType(
cls._meta.type_name, cls._meta.type_name,
description=cls._meta.description, description=cls._meta.description,
resolve_type=lambda *args, **kwargs: cls.resolve_type(schema, *args, **kwargs), resolve_type=lambda *
args, **kwargs: cls.resolve_type(schema, *args, **kwargs),
fields=fields fields=fields
) )
return GraphQLObjectType( return GraphQLObjectType(
@ -157,5 +161,6 @@ class BaseObjectType(object):
class ObjectType(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): class ObjectType(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass pass
class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): class Interface(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass pass

View File

@ -2,6 +2,7 @@ from graphene.core.schema import Schema
_global_schema = None _global_schema = None
def get_global_schema(): def get_global_schema():
global _global_schema global _global_schema
if not _global_schema: if not _global_schema:

View File

@ -13,4 +13,5 @@ def object_type_created(object_type):
type_name = object_type._meta.type_name type_name = object_type._meta.type_name
field = NodeIDField() field = NodeIDField()
object_type.add_to_class('id', field) object_type.add_to_class('id', field)
assert hasattr(object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name assert hasattr(
object_type, 'get_node'), 'get_node classmethod not found in %s Node' % type_name

View File

@ -16,6 +16,7 @@ from graphene.utils import memoize
class ConnectionField(Field): class ConnectionField(Field):
def __init__(self, field_type, resolve=None, description=''): def __init__(self, field_type, resolve=None, description=''):
super(ConnectionField, self).__init__(field_type, resolve=resolve, super(ConnectionField, self).__init__(field_type, resolve=resolve,
args=connectionArgs, description=description) args=connectionArgs, description=description)
@ -27,23 +28,27 @@ class ConnectionField(Field):
resolved = super(ConnectionField, self).resolve(instance, args, info) resolved = super(ConnectionField, self).resolve(instance, args, info)
if resolved: if resolved:
resolved = self.wrap_resolved(resolved, instance, args, info) resolved = self.wrap_resolved(resolved, instance, args, info)
assert isinstance(resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable' assert isinstance(
resolved, collections.Iterable), 'Resolved value from the connection field have to be iterable'
return connectionFromArray(resolved, args) return connectionFromArray(resolved, args)
@memoize @memoize
def internal_type(self, schema): def internal_type(self, schema):
from graphene.relay.types import BaseNode from graphene.relay.types import BaseNode
object_type = self.get_object_type(schema) object_type = self.get_object_type(schema)
assert issubclass(object_type, BaseNode), 'Only nodes have connections.' assert issubclass(
object_type, BaseNode), 'Only nodes have connections.'
return object_type.get_connection(schema) return object_type.get_connection(schema)
class NodeField(LazyNativeField): class NodeField(LazyNativeField):
def get_field(self, schema): def get_field(self, schema):
from graphene.relay.types import BaseNode from graphene.relay.types import BaseNode
return BaseNode.get_definitions(schema).nodeField return BaseNode.get_definitions(schema).nodeField
class NodeIDField(LazyNativeField): class NodeIDField(LazyNativeField):
def get_field(self, schema): def get_field(self, schema):
return globalIdField(self.object_type._meta.type_name) return globalIdField(self.object_type._meta.type_name)

View File

@ -23,6 +23,7 @@ def get_node(schema, globalId, *args):
class BaseNode(object): class BaseNode(object):
@classmethod @classmethod
@memoize @memoize
def get_definitions(cls, schema): def get_definitions(cls, schema):

View File

@ -1,7 +1,9 @@
from graphene.relay.types import BaseNode from graphene.relay.types import BaseNode
def is_node(object_type): def is_node(object_type):
return issubclass(object_type, BaseNode) and not is_node_type(object_type) return issubclass(object_type, BaseNode) and not is_node_type(object_type)
def is_node_type(object_type): def is_node_type(object_type):
return BaseNode in object_type.__bases__ return BaseNode in object_type.__bases__

View File

@ -1,6 +1,8 @@
from functools import wraps from functools import wraps
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
with an ordinary attribute. Deleting the attribute resets the property. with an ordinary attribute. Deleting the attribute resets the property.

View File

@ -8,10 +8,12 @@ r.save()
r2 = Reporter(first_name='Paul', last_name='Jones', email='paul@example.com') r2 = Reporter(first_name='Paul', last_name='Jones', email='paul@example.com')
r2.save() r2.save()
a = Article(id=None, headline="This is a test", pub_date=date(2005, 7, 27), reporter=r) a = Article(id=None, headline="This is a test",
pub_date=date(2005, 7, 27), reporter=r)
a.save() a.save()
new_article = r.articles.create(headline="John's second story", pub_date=date(2005, 7, 29)) new_article = r.articles.create(
headline="John's second story", pub_date=date(2005, 7, 29))
new_article2 = Article(headline="Paul's story", pub_date=date(2006, 1, 17)) new_article2 = Article(headline="Paul's story", pub_date=date(2006, 1, 17))
r.articles.add(new_article2) r.articles.add(new_article2)

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from django.db import models from django.db import models
class Reporter(models.Model): class Reporter(models.Model):
first_name = models.CharField(max_length=30) first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30)
@ -12,6 +13,7 @@ class Reporter(models.Model):
class Meta: class Meta:
app_label = 'contrib_django' app_label = 'contrib_django'
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField()

View File

@ -20,6 +20,7 @@ def test_should_raise_if_no_model():
def test_should_raise_if_model_is_invalid(): def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character2(DjangoObjectType): class Character2(DjangoObjectType):
class Meta: class Meta:
model = 1 model = 1
assert 'not a Django model' in str(excinfo.value) assert 'not a Django model' in str(excinfo.value)
@ -28,6 +29,7 @@ def test_should_raise_if_model_is_invalid():
def test_should_raise_if_model_is_invalid(): def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class ReporterTypeError(DjangoObjectType): class ReporterTypeError(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
only_fields = ('articles', ) only_fields = ('articles', )
@ -41,18 +43,22 @@ def test_should_raise_if_model_is_invalid():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert 'articles (Article) model not mapped in current schema' in str(excinfo.value) assert 'articles (Article) model not mapped in current schema' in str(
excinfo.value)
def test_should_map_fields_correctly(): def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType): class ReporterType2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
assert ReporterType2._meta.fields_map.keys() == ['articles', 'first_name', 'last_name', 'id', 'email'] assert ReporterType2._meta.fields_map.keys(
) == ['articles', 'first_name', 'last_name', 'id', 'email']
def test_should_map_fields(): def test_should_map_fields():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -86,13 +92,16 @@ def test_should_map_fields():
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
only_fields = ('id', 'email') only_fields = ('id', 'email')
assert Reporter2._meta.fields_map.keys() == ['id', 'email'] assert Reporter2._meta.fields_map.keys() == ['id', 'email']
def test_should_node(): def test_should_node():
class ReporterNodeType(DjangoNode): class ReporterNodeType(DjangoNode):
class Meta: class Meta:
model = Reporter model = Reporter
@ -104,6 +113,7 @@ def test_should_node():
return [ArticleNodeType(Article(headline='Hi!'))] return [ArticleNodeType(Article(headline='Hi!'))]
class ArticleNodeType(DjangoNode): class ArticleNodeType(DjangoNode):
class Meta: class Meta:
model = Article model = Article

View File

@ -20,13 +20,16 @@ from .models import Reporter, Article
class Character(DjangoInterface): class Character(DjangoInterface):
'''Character description''' '''Character description'''
class Meta: class Meta:
model = Reporter model = Reporter
class Human(DjangoNode): class Human(DjangoNode):
'''Human description''' '''Human description'''
def get_node(self, id): def get_node(self, id):
pass pass
@ -39,12 +42,14 @@ schema = Schema()
def test_django_interface(): def test_django_interface():
assert DjangoNode._meta.interface == True assert DjangoNode._meta.interface == True
def test_pseudo_interface(): def test_pseudo_interface():
object_type = Character.internal_type(schema) object_type = Character.internal_type(schema)
assert Character._meta.interface == True assert Character._meta.interface == True
assert isinstance(object_type, GraphQLInterfaceType) assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.model == Reporter assert Character._meta.model == Reporter
assert object_type.get_fields().keys() == ['articles', 'first_name', 'last_name', 'id', 'email'] assert object_type.get_fields().keys() == [
'articles', 'first_name', 'last_name', 'id', 'email']
def test_interface_resolve_type(): def test_interface_resolve_type():

View File

@ -17,10 +17,13 @@ from graphql.core.type import (
GraphQLID, GraphQLID,
) )
class ObjectType(object): class ObjectType(object):
_meta = Options() _meta = Options()
def resolve(self, *args, **kwargs): def resolve(self, *args, **kwargs):
return None return None
def can_resolve(self, *args): def can_resolve(self, *args):
return True return True
@ -28,11 +31,13 @@ ot = ObjectType()
ObjectType._meta.contribute_to_class(ObjectType, '_meta') ObjectType._meta.contribute_to_class(ObjectType, '_meta')
class Schema(object): class Schema(object):
pass pass
schema = Schema() schema = Schema()
def test_field_no_contributed_raises_error(): def test_field_no_contributed_raises_error():
f = Field(GraphQLString) f = Field(GraphQLString)
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
@ -71,6 +76,7 @@ def test_field_resolve_type_custom():
pass pass
class Schema(object): class Schema(object):
def get_type(self, name): def get_type(self, name):
if name == 'MyCustomType': if name == 'MyCustomType':
return MyCustomType return MyCustomType

View File

@ -8,13 +8,16 @@ from graphene.core.fields import (
from graphene.core.options import Options from graphene.core.options import Options
class Meta: class Meta:
interface = True interface = True
type_name = 'Character' type_name = 'Character'
class InvalidMeta: class InvalidMeta:
other_value = True other_value = True
def test_field_added_in_meta(): def test_field_added_in_meta():
opt = Options(Meta) opt = Options(Meta)
@ -27,6 +30,7 @@ def test_field_added_in_meta():
opt.add_field(f) opt.add_field(f)
assert f in opt.fields assert f in opt.fields
def test_options_contribute(): def test_options_contribute():
opt = Options(Meta) opt = Options(Meta)
@ -36,6 +40,7 @@ def test_options_contribute():
opt.contribute_to_class(ObjectType, '_meta') opt.contribute_to_class(ObjectType, '_meta')
assert ObjectType._meta == opt assert ObjectType._meta == opt
def test_options_typename(): def test_options_typename():
opt = Options(Meta) opt = Options(Meta)
@ -45,16 +50,19 @@ def test_options_typename():
opt.contribute_to_class(ObjectType, '_meta') opt.contribute_to_class(ObjectType, '_meta')
assert opt.type_name == 'Character' assert opt.type_name == 'Character'
def test_options_description(): def test_options_description():
opt = Options(Meta) opt = Options(Meta)
class ObjectType(object): class ObjectType(object):
'''False description''' '''False description'''
pass pass
opt.contribute_to_class(ObjectType, '_meta') opt.contribute_to_class(ObjectType, '_meta')
assert opt.description == 'False description' assert opt.description == 'False description'
def test_field_no_contributed_raises_error(): def test_field_no_contributed_raises_error():
opt = Options(InvalidMeta) opt = Options(InvalidMeta)

View File

@ -46,17 +46,21 @@ class Human(Character):
schema.query = Human schema.query = Human
def test_get_registered_type(): def test_get_registered_type():
assert schema.get_type('Character') == Character assert schema.get_type('Character') == Character
def test_get_unregistered_type(): def test_get_unregistered_type():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
schema.get_type('NON_EXISTENT_MODEL') schema.get_type('NON_EXISTENT_MODEL')
assert 'not found' in str(excinfo.value) assert 'not found' in str(excinfo.value)
def test_schema_query(): def test_schema_query():
assert schema.query == Human assert schema.query == Human
def test_query_schema_graphql(): def test_query_schema_graphql():
a = object() a = object()
query = ''' query = '''
@ -100,4 +104,5 @@ def test_query_schema_execute():
def test_schema_get_type_map(): def test_schema_get_type_map():
assert schema.schema.get_type_map().keys() == ['__Field', 'String', 'Pet', 'Character', '__InputValue', '__Directive', '__TypeKind', '__Schema', '__Type', 'Human', '__EnumValue', 'Boolean'] assert schema.schema.get_type_map().keys() == [
'__Field', 'String', 'Pet', 'Character', '__InputValue', '__Directive', '__TypeKind', '__Schema', '__Type', 'Human', '__EnumValue', 'Boolean']

View File

@ -17,13 +17,16 @@ from graphene.core.types import (
class Character(Interface): class Character(Interface):
'''Character description''' '''Character description'''
name = StringField() name = StringField()
class Meta: class Meta:
type_name = 'core.Character' type_name = 'core.Character'
class Human(Character): class Human(Character):
'''Human description''' '''Human description'''
friends = StringField() friends = StringField()
@ -39,7 +42,8 @@ def test_interface():
assert isinstance(object_type, GraphQLInterfaceType) assert isinstance(object_type, GraphQLInterfaceType)
assert Character._meta.type_name == 'core.Character' assert Character._meta.type_name == 'core.Character'
assert object_type.description == 'Character description' assert object_type.description == 'Character description'
assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].internal_field(schema)} assert object_type.get_fields() == {
'name': Character._meta.fields_map['name'].internal_field(schema)}
def test_interface_resolve_type(): def test_interface_resolve_type():
@ -53,5 +57,6 @@ def test_object_type():
assert Human._meta.type_name == 'core.Human' assert Human._meta.type_name == 'core.Human'
assert isinstance(object_type, GraphQLObjectType) assert isinstance(object_type, GraphQLObjectType)
assert object_type.description == 'Human description' assert object_type.description == 'Human description'
assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].internal_field(schema), 'friends': Human._meta.fields_map['friends'].internal_field(schema)} assert object_type.get_fields() == {'name': Character._meta.fields_map['name'].internal_field(
schema), 'friends': Human._meta.fields_map['friends'].internal_field(schema)}
assert object_type.get_interfaces() == [Character.internal_type(schema)] assert object_type.get_interfaces() == [Character.internal_type(schema)]

View File

@ -5,6 +5,7 @@ from graphene import relay
schema = graphene.Schema() schema = graphene.Schema()
class OtherNode(relay.Node): class OtherNode(relay.Node):
name = graphene.StringField() name = graphene.StringField()

View File

@ -73,6 +73,7 @@ droidData = {
'2001': artoo, '2001': artoo,
} }
def getCharacter(id): def getCharacter(id):
return humanData.get(id) or droidData.get(id) return humanData.get(id) or droidData.get(id)

View File

@ -2,6 +2,7 @@ from collections import namedtuple
from .models import Ship, Faction from .models import Ship, Faction
def initialize(): def initialize():
rebels = Faction( rebels = Faction(
id='1', id='1',
@ -15,7 +16,6 @@ def initialize():
) )
empire.save() empire.save()
xwing = Ship( xwing = Ship(
id='1', id='1',
name='X-Wing', name='X-Wing',
@ -88,14 +88,18 @@ def createShip(shipName, factionId):
def getShip(_id): def getShip(_id):
return Ship.objects.get(id=_id) return Ship.objects.get(id=_id)
def getShips(): def getShips():
return Ship.objects.all() return Ship.objects.all()
def getFaction(_id): def getFaction(_id):
return Faction.objects.get(id=_id) return Faction.objects.get(id=_id)
def getRebels(): def getRebels():
return getFaction(1) return getFaction(1)
def getEmpire(): def getEmpire():
return getFaction(2) return getFaction(2)

View File

@ -15,7 +15,9 @@ from .data import (
schema = graphene.Schema(name='Starwars Django Relay Schema') schema = graphene.Schema(name='Starwars Django Relay Schema')
class Ship(DjangoNode): class Ship(DjangoNode):
class Meta: class Meta:
model = ShipModel model = ShipModel
@ -23,7 +25,9 @@ class Ship(DjangoNode):
def get_node(cls, id): def get_node(cls, id):
return Ship(getShip(id)) return Ship(getShip(id))
class Faction(DjangoNode): class Faction(DjangoNode):
class Meta: class Meta:
model = FactionModel model = FactionModel

View File

@ -7,6 +7,7 @@ from .data import initialize
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def test_correct_fetch_first_ship_rebels(): def test_correct_fetch_first_ship_rebels():
initialize() initialize()
query = ''' query = '''

View File

@ -7,6 +7,7 @@ from .schema import schema
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def test_correctly_fetches_id_name_rebels(): def test_correctly_fetches_id_name_rebels():
initialize() initialize()
query = ''' query = '''
@ -27,6 +28,7 @@ def test_correctly_fetches_id_name_rebels():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_rebels(): def test_correctly_refetches_rebels():
initialize() initialize()
query = ''' query = '''
@ -49,6 +51,7 @@ def test_correctly_refetches_rebels():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_fetches_id_name_empire(): def test_correctly_fetches_id_name_empire():
initialize() initialize()
query = ''' query = '''
@ -69,6 +72,7 @@ def test_correctly_fetches_id_name_empire():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_empire(): def test_correctly_refetches_empire():
initialize() initialize()
query = ''' query = '''
@ -91,6 +95,7 @@ def test_correctly_refetches_empire():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_xwing(): def test_correctly_refetches_xwing():
initialize() initialize()
query = ''' query = '''

View File

@ -74,6 +74,7 @@ data = {
} }
} }
def createShip(shipName, factionId): def createShip(shipName, factionId):
nextShip = len(data['Ship'].keys())+1 nextShip = len(data['Ship'].keys())+1
newShip = Ship( newShip = Ship(
@ -88,11 +89,14 @@ def createShip(shipName, factionId):
def getShip(_id): def getShip(_id):
return data['Ship'][_id] return data['Ship'][_id]
def getFaction(_id): def getFaction(_id):
return data['Faction'][_id] return data['Faction'][_id]
def getRebels(): def getRebels():
return rebels return rebels
def getEmpire(): def getEmpire():
return empire return empire

View File

@ -10,7 +10,9 @@ from .data import (
schema = graphene.Schema(name='Starwars Relay Schema') schema = graphene.Schema(name='Starwars Relay Schema')
class Ship(relay.Node): class Ship(relay.Node):
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
name = graphene.StringField(description='The name of the ship.') name = graphene.StringField(description='The name of the ship.')
@ -20,9 +22,11 @@ class Ship(relay.Node):
class Faction(relay.Node): class Faction(relay.Node):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
name = graphene.StringField(description='The name of the faction.') name = graphene.StringField(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.') ships = relay.ConnectionField(
Ship, description='The ships used by the faction.')
@resolve_only_args @resolve_only_args
def resolve_ships(self, **kwargs): def resolve_ships(self, **kwargs):

View File

@ -11,6 +11,7 @@ Episode = graphene.Enum('Episode', dict(
JEDI=6 JEDI=6
)) ))
def wrap_character(character): def wrap_character(character):
if isinstance(character, _Human): if isinstance(character, _Human):
return Human(character) return Human(character)

View File

@ -3,6 +3,7 @@ from graphql.core import graphql
from .schema import schema from .schema import schema
def test_correct_fetch_first_ship_rebels(): def test_correct_fetch_first_ship_rebels():
query = ''' query = '''
query RebelsShipsQuery { query RebelsShipsQuery {

View File

@ -3,6 +3,7 @@ from graphql.core import graphql
from .schema import schema from .schema import schema
def test_correctly_fetches_id_name_rebels(): def test_correctly_fetches_id_name_rebels():
query = ''' query = '''
query RebelsQuery { query RebelsQuery {
@ -22,6 +23,7 @@ def test_correctly_fetches_id_name_rebels():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_rebels(): def test_correctly_refetches_rebels():
query = ''' query = '''
query RebelsRefetchQuery { query RebelsRefetchQuery {
@ -43,6 +45,7 @@ def test_correctly_refetches_rebels():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_fetches_id_name_empire(): def test_correctly_fetches_id_name_empire():
query = ''' query = '''
query EmpireQuery { query EmpireQuery {
@ -62,6 +65,7 @@ def test_correctly_fetches_id_name_empire():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_empire(): def test_correctly_refetches_empire():
query = ''' query = '''
query EmpireRefetchQuery { query EmpireRefetchQuery {
@ -83,6 +87,7 @@ def test_correctly_refetches_empire():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_correctly_refetches_xwing(): def test_correctly_refetches_xwing():
query = ''' query = '''
query XWingRefetchQuery { query XWingRefetchQuery {