Improved integration with Graphene 2.0

This commit is contained in:
Syrus Akbary 2017-07-24 22:27:50 -07:00
parent 18db46e132
commit 48bcccdac2
10 changed files with 86 additions and 115 deletions

View File

@ -10,7 +10,7 @@ from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
from .compat import ArrayField, HStoreField, JSONField, RangeField
from .fields import get_connection_field, DjangoListField
from .fields import DjangoListField, DjangoConnectionField
from .utils import get_related_model, import_single_dispatch
singledispatch = import_single_dispatch()
@ -148,8 +148,16 @@ def convert_field_to_list_or_connection(field, registry=None):
if not _type:
return
if is_node(_type):
return get_connection_field(_type)
# If there is a connection, we should transform the field
# into a DjangoConnectionField
if _type._meta.connection:
# Use a DjangoFilterConnectionField if there are
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)
return DjangoListField(_type)

View File

@ -43,6 +43,13 @@ class DjangoConnectionField(ConnectionField):
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
return _type._meta.connection
@property
def node_type(self):
return self.type._meta.node
@ -128,10 +135,3 @@ class DjangoConnectionField(ConnectionField):
self.max_limit,
self.enforce_first_or_last
)
def get_connection_field(*args, **kwargs):
if DJANGO_FILTER_INSTALLED:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(*args, **kwargs)
return DjangoConnectionField(*args, **kwargs)

View File

@ -356,7 +356,7 @@ def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)
def resolve_child_reporters(self, args, context, info):
def resolve_child_reporters(self, **args):
return []
class Meta:
@ -399,7 +399,7 @@ def test_should_query_filter_node_limit():
filterset_class=ReporterFilter
)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, **args):
return Reporter.objects.order_by('a_choice')
Reporter.objects.create(
@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises():
filterset_class=ReporterFilter
)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, **args):
return Reporter.objects.order_by('a_choice')[:2]
Reporter.objects.create(

View File

@ -1,8 +1,10 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._registry_models = {}
self._connection_types = {}
def register(self, cls):
from .types import DjangoObjectType

View File

@ -3,16 +3,14 @@ from functools import partial
import six
import graphene
from graphene.types import Argument, Field
from graphene.types.mutation import Mutation, MutationMeta
from graphene import relay
from graphene.types import Argument, Field, InputField
from graphene.types.mutation import Mutation, MutationOptions
from graphene.types.objecttype import (
ObjectTypeMeta,
merge,
yank_fields_from_attrs
)
from graphene.types.options import Options
from graphene.types.utils import get_field_as
from graphene.utils.is_base_type import is_base_type
from .serializer_converter import (
convert_serializer_to_input_type,
@ -21,91 +19,53 @@ from .serializer_converter import (
from .types import ErrorType
class SerializerMutationOptions(Options):
def __init__(self, *args, **kwargs):
super().__init__(*args, serializer_class=None, **kwargs)
class SerializerMutationOptions(MutationOptions):
serializer_class = None
class SerializerMutationMeta(MutationMeta):
def __new__(cls, name, bases, attrs):
if not is_base_type(bases, SerializerMutationMeta):
return type.__new__(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=name,
description=attrs.pop('__doc__', None),
serializer_class=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
registry=None
def fields_for_serializer(serializer, only_fields, exclude_fields):
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in exclude_fields # or
# name in already_created_fields
)
if not options.serializer_class:
raise Exception('Missing serializer_class')
if is_not_in_only or is_excluded:
continue
cls = ObjectTypeMeta.__new__(
cls, name, bases, dict(attrs, _meta=options)
)
fields[name] = convert_serializer_field(field, is_input=False)
return fields
serializer_fields = cls.fields_for_serializer(options)
options.serializer_fields = yank_fields_from_attrs(
class SerializerMutation(relay.ClientIDMutation):
errors = graphene.List(
ErrorType,
description='May contain more than one error for same field.'
)
@classmethod
def __init_subclass_with_meta__(cls, serializer_class,
only_fields=(), exclude_fields=(), **options):
if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
serializer = serializer_class()
serializer_fields = fields_for_serializer(serializer, only_fields, exclude_fields)
_meta = SerializerMutationOptions(cls)
_meta.fields = yank_fields_from_attrs(
serializer_fields,
_as=Field,
)
options.fields = merge(
options.interface_fields, options.serializer_fields,
options.base_fields, options.local_fields,
{'errors': get_field_as(cls.errors, Field)}
_meta.input_fields = yank_fields_from_attrs(
serializer_fields,
_as=InputField,
)
cls.Input = convert_serializer_to_input_type(options.serializer_class)
cls.Field = partial(
Field,
cls,
resolver=cls.mutate,
input=Argument(cls.Input, required=True)
)
return cls
@staticmethod
def fields_for_serializer(options):
serializer = options.serializer_class()
only_fields = options.only_fields
already_created_fields = {
name
for name, _ in options.local_fields.items()
}
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in options.exclude_fields or
name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_serializer_field(field, is_input=False)
return fields
class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)):
errors = graphene.List(
ErrorType,
description='May contain more than one error for '
'same field.'
)
@classmethod
def mutate(cls, instance, args, request, info):
input = args.get('input')

View File

@ -22,7 +22,7 @@ class Human(DjangoObjectType):
model = Article
interfaces = (relay.Node, )
def resolve_raises(self, *args):
def resolve_raises(self):
raise Exception("This field should raise exception")
def get_node(self, id):
@ -32,7 +32,7 @@ class Human(DjangoObjectType):
class Query(graphene.ObjectType):
human = graphene.Field(Human)
def resolve_human(self, args, context, info):
def resolve_human(self):
return Human()

View File

@ -1,5 +1,5 @@
import graphene
from graphene import ObjectType, Schema
from graphene import ObjectType, Schema, annotate, Context
class QueryRoot(ObjectType):
@ -8,21 +8,21 @@ class QueryRoot(ObjectType):
request = graphene.String(required=True)
test = graphene.String(who=graphene.String())
def resolve_thrower(self, args, context, info):
def resolve_thrower(self):
raise Exception("Throws!")
def resolve_request(self, args, context, info):
request = context
@annotate(request=Context)
def resolve_request(self, request):
return request.GET.get('q')
def resolve_test(self, args, context, info):
return 'Hello %s' % (args.get('who') or 'World')
def resolve_test(self, who=None):
return 'Hello %s' % (who or 'World')
class MutationRoot(ObjectType):
write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, args, context, info):
def resolve_write_test(self):
return QueryRoot()

View File

@ -224,7 +224,7 @@ def test_should_manytomany_convert_connectionorlist_connection():
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField)
assert dynamic_field.type == A.Connection
assert dynamic_field.type == A._meta.connection
def test_should_manytoone_convert_connectionorlist():

View File

@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects():
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(self, args, context, info):
def resolve_reporter(self):
return SimpleLazyObject(lambda: Reporter(id=1))
schema = graphene.Schema(query=Query)
@ -75,7 +75,7 @@ def test_should_query_well():
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs):
def resolve_reporter(self):
return Reporter(first_name='ABA', last_name='X')
query = '''
@ -119,7 +119,7 @@ def test_should_query_postgres_fields():
class Query(graphene.ObjectType):
event = graphene.Field(EventType)
def resolve_event(self, *args, **kwargs):
def resolve_event(self):
return Event(
ages=(0, 10),
data={'angry_babies': True},
@ -165,7 +165,7 @@ def test_should_node():
def get_node(cls, id, context, info):
return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, *args, **kwargs):
def resolve_articles(self, **args):
return [Article(headline='Hi!')]
class ArticleNode(DjangoObjectType):
@ -183,7 +183,7 @@ def test_should_node():
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
def resolve_reporter(self, *args, **kwargs):
def resolve_reporter(self):
return Reporter(id=1, first_name='ABA', last_name='X')
query = '''
@ -250,7 +250,7 @@ def test_should_query_connectionfields():
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, **args):
return [Reporter(id=1)]
schema = graphene.Schema(query=Query)
@ -308,10 +308,10 @@ def test_should_keep_annotations():
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoConnectionField(ArticleType)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, **args):
return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c')
def resolve_all_articles(self, args, context, info):
def resolve_all_articles(self, **args):
return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg')
schema = graphene.Schema(query=Query)
@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields():
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, *args, **kwargs):
def resolve_all_reporters(self, **args):
return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query)
@ -673,10 +673,11 @@ def test_should_query_dataloader_fields():
class Meta:
model = Reporter
interfaces = (Node, )
use_connection = True
articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, *args, **kwargs):
def resolve_articles(self, **args):
return article_loader.load(self.id)
class Query(graphene.ObjectType):

View File

@ -1,7 +1,7 @@
from setuptools import find_packages, setup
rest_framework_require = [
'djangorestframework==3.6.3',
'djangorestframework>=3.6.3',
]
@ -17,7 +17,7 @@ tests_require = [
setup(
name='graphene-django',
version='1.3',
version='2.0.dev',
description='Graphene Django integration',
long_description=open('README.rst').read(),
@ -48,11 +48,11 @@ setup(
install_requires=[
'six>=1.10.0',
'graphene>=1.4',
'graphene>=2.0.dev',
'Django>=1.8.0',
'iso8601',
'singledispatch>=3.4.0.3',
'promise>=2.0',
'promise>=2.1.dev',
],
setup_requires=[
'pytest-runner',