Improved integration with Graphene 2.0

This commit is contained in:
Syrus Akbary 2017-07-24 22:27:50 -07:00
parent 18db46e132
commit 48bcccdac2
10 changed files with 86 additions and 115 deletions

View File

@ -10,7 +10,7 @@ from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name from graphql import assert_valid_name
from .compat import ArrayField, HStoreField, JSONField, RangeField from .compat import ArrayField, HStoreField, JSONField, RangeField
from .fields import get_connection_field, DjangoListField from .fields import DjangoListField, DjangoConnectionField
from .utils import get_related_model, import_single_dispatch from .utils import get_related_model, import_single_dispatch
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@ -148,8 +148,16 @@ def convert_field_to_list_or_connection(field, registry=None):
if not _type: if not _type:
return return
if is_node(_type): # If there is a connection, we should transform the field
return get_connection_field(_type) # into a DjangoConnectionField
if _type._meta.connection:
# Use a DjangoFilterConnectionField if there are
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)
return DjangoListField(_type) return DjangoListField(_type)

View File

@ -43,6 +43,13 @@ class DjangoConnectionField(ConnectionField):
) )
super(DjangoConnectionField, self).__init__(*args, **kwargs) super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
return _type._meta.connection
@property @property
def node_type(self): def node_type(self):
return self.type._meta.node return self.type._meta.node
@ -128,10 +135,3 @@ class DjangoConnectionField(ConnectionField):
self.max_limit, self.max_limit,
self.enforce_first_or_last self.enforce_first_or_last
) )
def get_connection_field(*args, **kwargs):
if DJANGO_FILTER_INSTALLED:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(*args, **kwargs)
return DjangoConnectionField(*args, **kwargs)

View File

@ -356,7 +356,7 @@ def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode) child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)
def resolve_child_reporters(self, args, context, info): def resolve_child_reporters(self, **args):
return [] return []
class Meta: class Meta:
@ -399,7 +399,7 @@ def test_should_query_filter_node_limit():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, **args):
return Reporter.objects.order_by('a_choice') return Reporter.objects.order_by('a_choice')
Reporter.objects.create( Reporter.objects.create(
@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, **args):
return Reporter.objects.order_by('a_choice')[:2] return Reporter.objects.order_by('a_choice')[:2]
Reporter.objects.create( Reporter.objects.create(

View File

@ -1,8 +1,10 @@
class Registry(object): class Registry(object):
def __init__(self): def __init__(self):
self._registry = {} self._registry = {}
self._registry_models = {} self._registry_models = {}
self._connection_types = {}
def register(self, cls): def register(self, cls):
from .types import DjangoObjectType from .types import DjangoObjectType

View File

@ -3,16 +3,14 @@ from functools import partial
import six import six
import graphene import graphene
from graphene.types import Argument, Field from graphene import relay
from graphene.types.mutation import Mutation, MutationMeta from graphene.types import Argument, Field, InputField
from graphene.types.mutation import Mutation, MutationOptions
from graphene.types.objecttype import ( from graphene.types.objecttype import (
ObjectTypeMeta,
merge,
yank_fields_from_attrs yank_fields_from_attrs
) )
from graphene.types.options import Options from graphene.types.options import Options
from graphene.types.utils import get_field_as from graphene.types.utils import get_field_as
from graphene.utils.is_base_type import is_base_type
from .serializer_converter import ( from .serializer_converter import (
convert_serializer_to_input_type, convert_serializer_to_input_type,
@ -21,91 +19,53 @@ from .serializer_converter import (
from .types import ErrorType from .types import ErrorType
class SerializerMutationOptions(Options): class SerializerMutationOptions(MutationOptions):
def __init__(self, *args, **kwargs): serializer_class = None
super().__init__(*args, serializer_class=None, **kwargs)
class SerializerMutationMeta(MutationMeta): def fields_for_serializer(serializer, only_fields, exclude_fields):
def __new__(cls, name, bases, attrs): fields = OrderedDict()
if not is_base_type(bases, SerializerMutationMeta): for name, field in serializer.fields.items():
return type.__new__(cls, name, bases, attrs) is_not_in_only = only_fields and name not in only_fields
is_excluded = (
options = Options( name in exclude_fields # or
attrs.pop('Meta', None), # name in already_created_fields
name=name,
description=attrs.pop('__doc__', None),
serializer_class=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
registry=None
) )
if not options.serializer_class: if is_not_in_only or is_excluded:
raise Exception('Missing serializer_class') continue
cls = ObjectTypeMeta.__new__( fields[name] = convert_serializer_field(field, is_input=False)
cls, name, bases, dict(attrs, _meta=options) return fields
)
serializer_fields = cls.fields_for_serializer(options)
options.serializer_fields = yank_fields_from_attrs( class SerializerMutation(relay.ClientIDMutation):
errors = graphene.List(
ErrorType,
description='May contain more than one error for same field.'
)
@classmethod
def __init_subclass_with_meta__(cls, serializer_class,
only_fields=(), exclude_fields=(), **options):
if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
serializer = serializer_class()
serializer_fields = fields_for_serializer(serializer, only_fields, exclude_fields)
_meta = SerializerMutationOptions(cls)
_meta.fields = yank_fields_from_attrs(
serializer_fields, serializer_fields,
_as=Field, _as=Field,
) )
options.fields = merge( _meta.input_fields = yank_fields_from_attrs(
options.interface_fields, options.serializer_fields, serializer_fields,
options.base_fields, options.local_fields, _as=InputField,
{'errors': get_field_as(cls.errors, Field)}
) )
cls.Input = convert_serializer_to_input_type(options.serializer_class)
cls.Field = partial(
Field,
cls,
resolver=cls.mutate,
input=Argument(cls.Input, required=True)
)
return cls
@staticmethod
def fields_for_serializer(options):
serializer = options.serializer_class()
only_fields = options.only_fields
already_created_fields = {
name
for name, _ in options.local_fields.items()
}
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in options.exclude_fields or
name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_serializer_field(field, is_input=False)
return fields
class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)):
errors = graphene.List(
ErrorType,
description='May contain more than one error for '
'same field.'
)
@classmethod @classmethod
def mutate(cls, instance, args, request, info): def mutate(cls, instance, args, request, info):
input = args.get('input') input = args.get('input')

View File

@ -22,7 +22,7 @@ class Human(DjangoObjectType):
model = Article model = Article
interfaces = (relay.Node, ) interfaces = (relay.Node, )
def resolve_raises(self, *args): def resolve_raises(self):
raise Exception("This field should raise exception") raise Exception("This field should raise exception")
def get_node(self, id): def get_node(self, id):
@ -32,7 +32,7 @@ class Human(DjangoObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
human = graphene.Field(Human) human = graphene.Field(Human)
def resolve_human(self, args, context, info): def resolve_human(self):
return Human() return Human()

View File

@ -1,5 +1,5 @@
import graphene import graphene
from graphene import ObjectType, Schema from graphene import ObjectType, Schema, annotate, Context
class QueryRoot(ObjectType): class QueryRoot(ObjectType):
@ -8,21 +8,21 @@ class QueryRoot(ObjectType):
request = graphene.String(required=True) request = graphene.String(required=True)
test = graphene.String(who=graphene.String()) test = graphene.String(who=graphene.String())
def resolve_thrower(self, args, context, info): def resolve_thrower(self):
raise Exception("Throws!") raise Exception("Throws!")
def resolve_request(self, args, context, info): @annotate(request=Context)
request = context def resolve_request(self, request):
return request.GET.get('q') return request.GET.get('q')
def resolve_test(self, args, context, info): def resolve_test(self, who=None):
return 'Hello %s' % (args.get('who') or 'World') return 'Hello %s' % (who or 'World')
class MutationRoot(ObjectType): class MutationRoot(ObjectType):
write_test = graphene.Field(QueryRoot) write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, args, context, info): def resolve_write_test(self):
return QueryRoot() return QueryRoot()

View File

@ -224,7 +224,7 @@ def test_should_manytomany_convert_connectionorlist_connection():
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField) assert isinstance(dynamic_field, ConnectionField)
assert dynamic_field.type == A.Connection assert dynamic_field.type == A._meta.connection
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():

View File

@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self, args, context, info): def resolve_reporter(self):
return SimpleLazyObject(lambda: Reporter(id=1)) return SimpleLazyObject(lambda: Reporter(id=1))
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -75,7 +75,7 @@ def test_should_query_well():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self):
return Reporter(first_name='ABA', last_name='X') return Reporter(first_name='ABA', last_name='X')
query = ''' query = '''
@ -119,7 +119,7 @@ def test_should_query_postgres_fields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
event = graphene.Field(EventType) event = graphene.Field(EventType)
def resolve_event(self, *args, **kwargs): def resolve_event(self):
return Event( return Event(
ages=(0, 10), ages=(0, 10),
data={'angry_babies': True}, data={'angry_babies': True},
@ -165,7 +165,7 @@ def test_should_node():
def get_node(cls, id, context, info): def get_node(cls, id, context, info):
return Reporter(id=2, first_name='Cookie Monster') return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, **args):
return [Article(headline='Hi!')] return [Article(headline='Hi!')]
class ArticleNode(DjangoObjectType): class ArticleNode(DjangoObjectType):
@ -183,7 +183,7 @@ def test_should_node():
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self):
return Reporter(id=1, first_name='ABA', last_name='X') return Reporter(id=1, first_name='ABA', last_name='X')
query = ''' query = '''
@ -250,7 +250,7 @@ def test_should_query_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, **args):
return [Reporter(id=1)] return [Reporter(id=1)]
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -308,10 +308,10 @@ def test_should_keep_annotations():
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoConnectionField(ArticleType) all_articles = DjangoConnectionField(ArticleType)
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, **args):
return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c')
def resolve_all_articles(self, args, context, info): def resolve_all_articles(self, **args):
return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg')
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, **args):
return Promise.resolve([Reporter(id=1)]) return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -673,10 +673,11 @@ def test_should_query_dataloader_fields():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
use_connection = True
articles = DjangoConnectionField(ArticleType) articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, **args):
return article_loader.load(self.id) return article_loader.load(self.id)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):

View File

@ -1,7 +1,7 @@
from setuptools import find_packages, setup from setuptools import find_packages, setup
rest_framework_require = [ rest_framework_require = [
'djangorestframework==3.6.3', 'djangorestframework>=3.6.3',
] ]
@ -17,7 +17,7 @@ tests_require = [
setup( setup(
name='graphene-django', name='graphene-django',
version='1.3', version='2.0.dev',
description='Graphene Django integration', description='Graphene Django integration',
long_description=open('README.rst').read(), long_description=open('README.rst').read(),
@ -48,11 +48,11 @@ setup(
install_requires=[ install_requires=[
'six>=1.10.0', 'six>=1.10.0',
'graphene>=1.4', 'graphene>=2.0.dev',
'Django>=1.8.0', 'Django>=1.8.0',
'iso8601', 'iso8601',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
'promise>=2.0', 'promise>=2.1.dev',
], ],
setup_requires=[ setup_requires=[
'pytest-runner', 'pytest-runner',