Merge branch 'master' into fix-deprecation-warnings

This commit is contained in:
Syrus Akbary 2017-08-29 22:33:10 -07:00 committed by GitHub
commit a4cc360184
23 changed files with 220 additions and 253 deletions

View File

@ -4,7 +4,7 @@ Filtering
Graphene integrates with Graphene integrates with
`django-filter <https://django-filter.readthedocs.org>`__ to provide `django-filter <https://django-filter.readthedocs.org>`__ to provide
filtering of results. See the `usage filtering of results. See the `usage
documentation <https://django-filter.readthedocs.org/en/latest/usage.html#the-filter>`__ documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``. for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``. This filtering is automatically available when implementing a ``relay.Node``.
@ -26,7 +26,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering ``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.org/en/latest/usage.html#the-filter>`__ documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
for full details on the range of options available. for full details on the range of options available.
For example: For example:

View File

@ -16,7 +16,7 @@ class Ship(DjangoObjectType):
interfaces = (relay.Node, ) interfaces = (relay.Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
node = get_ship(id) node = get_ship(id)
return node return node
@ -34,7 +34,7 @@ class Faction(DjangoObjectType):
interfaces = (relay.Node, ) interfaces = (relay.Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
return get_faction(id) return get_faction(id)
@ -48,9 +48,7 @@ class IntroduceShip(relay.ClientIDMutation):
faction = graphene.Field(Faction) faction = graphene.Field(Faction)
@classmethod @classmethod
def mutate_and_get_payload(cls, input, context, info): def mutate_and_get_payload(cls, root, info, ship_name, faction_id, client_mutation_id=None):
ship_name = input.get('ship_name')
faction_id = input.get('faction_id')
ship = create_ship(ship_name, faction_id) ship = create_ship(ship_name, faction_id)
faction = get_faction(faction_id) faction = get_faction(faction_id)
return IntroduceShip(ship=ship, faction=faction) return IntroduceShip(ship=ship, faction=faction)

View File

@ -5,5 +5,10 @@ from .fields import (
DjangoConnectionField, DjangoConnectionField,
) )
__all__ = ['DjangoObjectType', __version__ = '2.0.dev2017073101'
'DjangoConnectionField']
__all__ = [
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]

View File

@ -2,15 +2,14 @@ from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String) NonNull, String, UUID)
from graphene.relay import is_node
from graphene.types.datetime import DateTime, Time from graphene.types.datetime import DateTime, Time
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name from graphql import assert_valid_name
from .compat import ArrayField, HStoreField, JSONField, RangeField from .compat import ArrayField, HStoreField, JSONField, RangeField
from .fields import get_connection_field, DjangoListField from .fields import DjangoListField, DjangoConnectionField
from .utils import import_single_dispatch from .utils import import_single_dispatch
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@ -79,11 +78,15 @@ def convert_field_to_string(field, registry=None):
@convert_django_field.register(models.AutoField) @convert_django_field.register(models.AutoField)
@convert_django_field.register(models.UUIDField)
def convert_field_to_id(field, registry=None): def convert_field_to_id(field, registry=None):
return ID(description=field.help_text, required=not field.null) return ID(description=field.help_text, required=not field.null)
@convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None):
return UUID(description=field.help_text, required=not field.null)
@convert_django_field.register(models.PositiveIntegerField) @convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField) @convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField) @convert_django_field.register(models.SmallIntegerField)
@ -148,8 +151,16 @@ def convert_field_to_list_or_connection(field, registry=None):
if not _type: if not _type:
return return
if is_node(_type): # If there is a connection, we should transform the field
return get_connection_field(_type) # into a DjangoConnectionField
if _type._meta.connection:
# Use a DjangoFilterConnectionField if there are
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)
return DjangoListField(_type) return DjangoListField(_type)

View File

@ -39,7 +39,8 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object): class DjangoDebugMiddleware(object):
def resolve(self, next, root, args, context, info): def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None) django_debug = getattr(context, 'django_debug', None)
if not django_debug: if not django_debug:
if context is None: if context is None:
@ -52,6 +53,6 @@ class DjangoDebugMiddleware(object):
)) ))
if info.schema.get_type('DjangoDebug') == info.return_type: if info.schema.get_type('DjangoDebug') == info.return_type:
return context.django_debug.get_debug_promise() return context.django_debug.get_debug_promise()
promise = next(root, args, context, info) promise = next(root, info, **args)
context.django_debug.add_promise(promise) context.django_debug.add_promise(promise)
return promise return promise

View File

@ -33,7 +33,7 @@ def test_should_query_field():
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, info, **args):
return Reporter.objects.first() return Reporter.objects.first()
query = ''' query = '''
@ -80,7 +80,7 @@ def test_should_query_list():
all_reporters = graphene.List(ReporterType) all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''
@ -129,7 +129,7 @@ def test_should_query_connection():
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''
@ -181,11 +181,11 @@ def test_should_query_connectionfilter():
interfaces = (Node, ) interfaces = (Node, )
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType) all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
s = graphene.String(resolver=lambda *_: "S") s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''

View File

@ -9,7 +9,7 @@ from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .settings import graphene_settings from .settings import graphene_settings
from .utils import DJANGO_FILTER_INSTALLED, maybe_queryset from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
@ -22,8 +22,8 @@ class DjangoListField(Field):
return self.type.of_type._meta.node._meta.model return self.type.of_type._meta.node._meta.model
@staticmethod @staticmethod
def list_resolver(resolver, root, args, context, info): def list_resolver(resolver, root, info, **args):
return maybe_queryset(resolver(root, args, context, info)) return maybe_queryset(resolver(root, info, **args))
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial(self.list_resolver, parent_resolver) return partial(self.list_resolver, parent_resolver)
@ -43,6 +43,14 @@ class DjangoConnectionField(ConnectionField):
) )
super(DjangoConnectionField, self).__init__(*args, **kwargs) super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
return _type._meta.connection
@property @property
def node_type(self): def node_type(self):
return self.type._meta.node return self.type._meta.node
@ -89,7 +97,7 @@ class DjangoConnectionField(ConnectionField):
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, args, context, info): enforce_first_or_last, root, info, **args):
first = args.get('first') first = args.get('first')
last = args.get('last') last = args.get('last')
@ -111,7 +119,7 @@ class DjangoConnectionField(ConnectionField):
).format(first, info.field_name, max_limit) ).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit) args['last'] = min(last, max_limit)
iterable = resolver(root, args, context, info) iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
if Promise.is_thenable(iterable): if Promise.is_thenable(iterable):
@ -128,10 +136,3 @@ class DjangoConnectionField(ConnectionField):
self.max_limit, self.max_limit,
self.enforce_first_or_last self.enforce_first_or_last
) )
def get_connection_field(*args, **kwargs):
if DJANGO_FILTER_INSTALLED:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(*args, **kwargs)
return DjangoConnectionField(*args, **kwargs)

View File

@ -1,7 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
# from graphene.relay import is_node
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -69,7 +68,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args, enforce_first_or_last, filterset_class, filtering_args,
root, args, context, info): root, info, **args):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( qs = filterset_class(
data=filter_kwargs, data=filter_kwargs,
@ -83,9 +82,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
max_limit, max_limit,
enforce_first_or_last, enforce_first_or_last,
root, root,
args, info,
context, **args
info
) )
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):

View File

@ -114,9 +114,9 @@ def test_filter_explicit_filterset_orderable():
assert_orderable(field) assert_orderable(field)
def test_filter_shortcut_filterset_orderable_true(): # def test_filter_shortcut_filterset_orderable_true():
field = DjangoFilterConnectionField(ReporterNode) # field = DjangoFilterConnectionField(ReporterNode)
assert_not_orderable(field) # assert_not_orderable(field)
# def test_filter_shortcut_filterset_orderable_headline(): # def test_filter_shortcut_filterset_orderable_headline():
@ -356,7 +356,7 @@ def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode) child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)
def resolve_child_reporters(self, args, context, info): def resolve_child_reporters(self, **args):
return [] return []
class Meta: class Meta:
@ -399,7 +399,7 @@ def test_should_query_filter_node_limit():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice') return Reporter.objects.order_by('a_choice')
Reporter.objects.create( Reporter.objects.create(
@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2] return Reporter.objects.order_by('a_choice')[:2]
Reporter.objects.create( Reporter.objects.create(

View File

@ -1,7 +1,7 @@
from django import forms from django import forms
from django.forms.fields import BaseTemporalField from django.forms.fields import BaseTemporalField
from graphene import ID, Boolean, Float, Int, List, String from graphene import ID, Boolean, Float, Int, List, String, UUID
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from .utils import import_single_dispatch from .utils import import_single_dispatch
@ -32,11 +32,15 @@ def convert_form_field(field):
@convert_form_field.register(forms.ChoiceField) @convert_form_field.register(forms.ChoiceField)
@convert_form_field.register(forms.RegexField) @convert_form_field.register(forms.RegexField)
@convert_form_field.register(forms.Field) @convert_form_field.register(forms.Field)
@convert_form_field.register(UUIDField)
def convert_form_field_to_string(field): def convert_form_field_to_string(field):
return String(description=field.help_text, required=field.required) return String(description=field.help_text, required=field.required)
@convert_form_field.register(UUIDField)
def convert_form_field_to_uuid(field):
return UUID(description=field.help_text, required=field.required)
@convert_form_field.register(forms.IntegerField) @convert_form_field.register(forms.IntegerField)
@convert_form_field.register(forms.NumberInput) @convert_form_field.register(forms.NumberInput)
def convert_form_field_to_int(field): def convert_form_field_to_int(field):

View File

@ -1,3 +1,4 @@
class Registry(object): class Registry(object):
def __init__(self): def __init__(self):

View File

@ -1,116 +1,74 @@
from collections import OrderedDict from collections import OrderedDict
from functools import partial
import six
import graphene import graphene
from graphene.types import Argument, Field from graphene.types import Field, InputField
from graphene.types.mutation import Mutation, MutationMeta from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import ( from graphene.types.objecttype import (
ObjectTypeMeta,
merge,
yank_fields_from_attrs yank_fields_from_attrs
) )
from graphene.types.options import Options
from graphene.types.utils import get_field_as
from graphene.utils.is_base_type import is_base_type
from .serializer_converter import ( from .serializer_converter import (
convert_serializer_to_input_type,
convert_serializer_field convert_serializer_field
) )
from .types import ErrorType from .types import ErrorType
class SerializerMutationOptions(Options): class SerializerMutationOptions(MutationOptions):
def __init__(self, *args, **kwargs): serializer_class = None
super().__init__(*args, serializer_class=None, **kwargs)
class SerializerMutationMeta(MutationMeta): def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False):
def __new__(cls, name, bases, attrs): fields = OrderedDict()
if not is_base_type(bases, SerializerMutationMeta): for name, field in serializer.fields.items():
return type.__new__(cls, name, bases, attrs) is_not_in_only = only_fields and name not in only_fields
is_excluded = (
options = Options( name in exclude_fields # or
attrs.pop('Meta', None), # name in already_created_fields
name=name,
description=attrs.pop('__doc__', None),
serializer_class=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
registry=None
) )
if not options.serializer_class: if is_not_in_only or is_excluded:
raise Exception('Missing serializer_class') continue
cls = ObjectTypeMeta.__new__( fields[name] = convert_serializer_field(field, is_input=is_input)
cls, name, bases, dict(attrs, _meta=options) return fields
)
serializer_fields = cls.fields_for_serializer(options)
options.serializer_fields = yank_fields_from_attrs(
serializer_fields,
_as=Field,
)
options.fields = merge(
options.interface_fields, options.serializer_fields,
options.base_fields, options.local_fields,
{'errors': get_field_as(cls.errors, Field)}
)
cls.Input = convert_serializer_to_input_type(options.serializer_class)
cls.Field = partial(
Field,
cls,
resolver=cls.mutate,
input=Argument(cls.Input, required=True)
)
return cls
@staticmethod
def fields_for_serializer(options):
serializer = options.serializer_class()
only_fields = options.only_fields
already_created_fields = {
name
for name, _ in options.local_fields.items()
}
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in options.exclude_fields or
name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_serializer_field(field, is_input=False)
return fields
class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)): class SerializerMutation(ClientIDMutation):
class Meta:
abstract = True
errors = graphene.List( errors = graphene.List(
ErrorType, ErrorType,
description='May contain more than one error for ' description='May contain more than one error for same field.'
'same field.'
) )
@classmethod @classmethod
def mutate(cls, instance, args, request, info): def __init_subclass_with_meta__(cls, serializer_class=None,
input = args.get('input') only_fields=(), exclude_fields=(), **options):
serializer = cls._meta.serializer_class(data=dict(input)) if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
serializer = serializer_class()
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
_meta = SerializerMutationOptions(cls)
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
)
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
serializer = cls._meta.serializer_class(data=input)
if serializer.is_valid(): if serializer.is_valid():
return cls.perform_mutate(serializer, info) return cls.perform_mutate(serializer, info)
@ -125,5 +83,4 @@ class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)):
@classmethod @classmethod
def perform_mutate(cls, serializer, info): def perform_mutate(cls, serializer, info):
obj = serializer.save() obj = serializer.save()
return cls(**obj)
return cls(errors=[], **obj)

View File

@ -2,6 +2,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers from rest_framework import serializers
import graphene import graphene
from graphene import Dynamic
from ..registry import get_global_registry from ..registry import get_global_registry
from ..utils import import_single_dispatch from ..utils import import_single_dispatch
@ -10,21 +11,6 @@ from .types import DictType
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
def convert_serializer_to_input_type(serializer_class):
serializer = serializer_class()
items = {
name: convert_serializer_field(field)
for name, field in serializer.fields.items()
}
return type(
'{}Input'.format(serializer.__class__.__name__),
(graphene.InputObjectType, ),
items
)
@singledispatch @singledispatch
def get_graphene_type_from_serializer_field(field): def get_graphene_type_from_serializer_field(field):
raise ImproperlyConfigured( raise ImproperlyConfigured(
@ -56,7 +42,8 @@ def convert_serializer_field(field, is_input=True):
if isinstance(field, serializers.ModelSerializer): if isinstance(field, serializers.ModelSerializer):
if is_input: if is_input:
graphql_type = convert_serializer_to_input_type(field.__class__) return Dynamic(lambda: None)
# graphql_type = convert_serializer_to_input_type(field.__class__)
else: else:
global_registry = get_global_registry() global_registry = get_global_registry()
field_model = field.Meta.model field_model = field.Meta.model

View File

@ -28,7 +28,7 @@ def test_needs_serializer_class():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
pass pass
assert exc.value.args[0] == 'Missing serializer_class' assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
def test_has_fields(): def test_has_fields():
@ -65,6 +65,7 @@ def test_nested_model():
assert model_field.type == MyFakeModelGrapheneType assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model'] model_input = MyMutation.Input._meta.fields['model']
model_input_type = model_input._type.of_type model_input_type = model_input.get_type()
assert issubclass(model_input_type, InputObjectType) assert not model_input_type
assert 'cool_name' in model_input_type._meta.fields # assert issubclass(model_input_type, InputObjectType)
# assert 'cool_name' in model_input_type._meta.fields

View File

@ -11,7 +11,7 @@ class Character(DjangoObjectType):
model = Reporter model = Reporter
interfaces = (relay.Node, ) interfaces = (relay.Node, )
def get_node(self, id, context, info): def get_node(self, info, id):
pass pass
@ -22,17 +22,17 @@ class Human(DjangoObjectType):
model = Article model = Article
interfaces = (relay.Node, ) interfaces = (relay.Node, )
def resolve_raises(self, *args): def resolve_raises(self, info):
raise Exception("This field should raise exception") raise Exception("This field should raise exception")
def get_node(self, id): def get_node(self, info, id):
pass pass
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
human = graphene.Field(Human) human = graphene.Field(Human)
def resolve_human(self, args, context, info): def resolve_human(self, info):
return Human() return Human()

View File

@ -8,21 +8,20 @@ class QueryRoot(ObjectType):
request = graphene.String(required=True) request = graphene.String(required=True)
test = graphene.String(who=graphene.String()) test = graphene.String(who=graphene.String())
def resolve_thrower(self, args, context, info): def resolve_thrower(self, info):
raise Exception("Throws!") raise Exception("Throws!")
def resolve_request(self, args, context, info): def resolve_request(self, info):
request = context return info.context.GET.get('q')
return request.GET.get('q')
def resolve_test(self, args, context, info): def resolve_test(self, info, who=None):
return 'Hello %s' % (args.get('who') or 'World') return 'Hello %s' % (who or 'World')
class MutationRoot(ObjectType): class MutationRoot(ObjectType):
write_test = graphene.Field(QueryRoot) write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, args, context, info): def resolve_write_test(self, info):
return QueryRoot() return QueryRoot()

View File

@ -84,7 +84,7 @@ def test_should_auto_convert_id():
def test_should_auto_convert_id(): def test_should_auto_convert_id():
assert_conversion(models.UUIDField, graphene.ID) assert_conversion(models.UUIDField, graphene.UUID)
def test_should_auto_convert_duration(): def test_should_auto_convert_duration():
@ -224,7 +224,7 @@ def test_should_manytomany_convert_connectionorlist_connection():
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField) assert isinstance(dynamic_field, ConnectionField)
assert dynamic_field.type == A.Connection assert dynamic_field.type == A._meta.connection
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():

View File

@ -65,7 +65,7 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'): if hasattr(forms, 'UUIDField'):
assert_conversion(forms.UUIDField, graphene.String) assert_conversion(forms.UUIDField, graphene.UUID)
def test_should_integer_convert_int(): def test_should_integer_convert_int():

View File

@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self, args, context, info): def resolve_reporter(self, info):
return SimpleLazyObject(lambda: Reporter(id=1)) return SimpleLazyObject(lambda: Reporter(id=1))
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -75,7 +75,7 @@ def test_should_query_well():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, info):
return Reporter(first_name='ABA', last_name='X') return Reporter(first_name='ABA', last_name='X')
query = ''' query = '''
@ -119,7 +119,7 @@ def test_should_query_postgres_fields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
event = graphene.Field(EventType) event = graphene.Field(EventType)
def resolve_event(self, *args, **kwargs): def resolve_event(self, info):
return Event( return Event(
ages=(0, 10), ages=(0, 10),
data={'angry_babies': True}, data={'angry_babies': True},
@ -162,10 +162,10 @@ def test_should_node():
interfaces = (Node, ) interfaces = (Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
return Reporter(id=2, first_name='Cookie Monster') return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, info, **args):
return [Article(headline='Hi!')] return [Article(headline='Hi!')]
class ArticleNode(DjangoObjectType): class ArticleNode(DjangoObjectType):
@ -175,7 +175,7 @@ def test_should_node():
interfaces = (Node, ) interfaces = (Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)) return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11))
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
@ -183,7 +183,7 @@ def test_should_node():
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, info):
return Reporter(id=1, first_name='ABA', last_name='X') return Reporter(id=1, first_name='ABA', last_name='X')
query = ''' query = '''
@ -250,7 +250,7 @@ def test_should_query_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, info, **args):
return [Reporter(id=1)] return [Reporter(id=1)]
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -308,10 +308,10 @@ def test_should_keep_annotations():
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoConnectionField(ArticleType) all_articles = DjangoConnectionField(ArticleType)
def resolve_all_reporters(self, args, context, info): def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c')
def resolve_all_articles(self, args, context, info): def resolve_all_articles(self, info, **args):
return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg')
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Promise.resolve([Reporter(id=1)]) return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -673,10 +673,11 @@ def test_should_query_dataloader_fields():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
use_connection = True
articles = DjangoConnectionField(ArticleType) articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, info, **args):
return article_loader.load(self.id) return article_loader.load(self.id)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):

View File

@ -38,7 +38,7 @@ def test_django_interface():
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) @patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
def test_django_get_node(get): def test_django_get_node(get):
article = Article.get_node(1, None, None) article = Article.get_node(None, 1)
get.assert_called_with(pk=1) get.assert_called_with(pk=1)
assert article.id == 1 assert article.id == 1

View File

@ -1,13 +1,10 @@
from collections import OrderedDict from collections import OrderedDict
import six
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject
from graphene import Field, ObjectType from graphene import Field
from graphene.types.objecttype import ObjectTypeMeta from graphene.relay import Connection, Node
from graphene.types.options import Options from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import merge, yank_fields_from_attrs from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.is_base_type import is_base_type
from .converter import convert_django_field_with_choices from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry from .registry import Registry, get_global_registry
@ -15,16 +12,14 @@ from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
is_valid_django_model) is_valid_django_model)
def construct_fields(options): def construct_fields(model, registry, only_fields, exclude_fields):
_model_fields = get_model_fields(options.model) _model_fields = get_model_fields(model)
only_fields = options.only_fields
exclude_fields = options.exclude_fields
fields = OrderedDict() fields = OrderedDict()
for name, field in _model_fields: for name, field in _model_fields:
is_not_in_only = only_fields and name not in options.only_fields is_not_in_only = only_fields and name not in only_fields
is_already_created = name in options.fields # is_already_created = name in options.fields
is_excluded = name in exclude_fields or is_already_created is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+') is_no_backref = str(name).endswith('+')
if is_not_in_only or is_excluded or is_no_backref: if is_not_in_only or is_excluded or is_no_backref:
@ -32,78 +27,74 @@ def construct_fields(options):
# in there. Or when we exclude this field in exclude_fields. # in there. Or when we exclude this field in exclude_fields.
# Or when there is no back reference. # Or when there is no back reference.
continue continue
converted = convert_django_field_with_choices(field, options.registry) converted = convert_django_field_with_choices(field, registry)
fields[name] = converted fields[name] = converted
return fields return fields
class DjangoObjectTypeMeta(ObjectTypeMeta): class DjangoObjectTypeOptions(ObjectTypeOptions):
model = None # type: Model
registry = None # type: Registry
connection = None # type: Type[Connection]
@staticmethod filter_fields = ()
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of
# DjangoObjectType
if not is_base_type(bases, DjangoObjectTypeMeta):
return type.__new__(cls, name, bases, attrs)
defaults = dict(
name=name,
description=attrs.pop('__doc__', None),
model=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
skip_registry=False,
registry=None
)
if DJANGO_FILTER_INSTALLED:
# In case Django filter is available, then
# we allow more attributes in Meta
defaults.update(
filter_fields=(),
)
options = Options( class DjangoObjectType(ObjectType):
attrs.pop('Meta', None), @classmethod
**defaults def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
) only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
if not options.registry: use_connection=None, interfaces=(), **options):
options.registry = get_global_registry() assert is_valid_django_model(model), (
assert isinstance(options.registry, Registry), (
'The attribute registry in {}.Meta needs to be an instance of '
'Registry, received "{}".'
).format(name, options.registry)
assert is_valid_django_model(options.model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".' 'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(name, options.model) ).format(cls.__name__, model)
cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) if not registry:
registry = get_global_registry()
options.registry.register(cls) assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of '
'Registry, received "{}".'
).format(cls.__name__, registry)
options.django_fields = yank_fields_from_attrs( if not DJANGO_FILTER_INSTALLED and filter_fields:
construct_fields(options), raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields),
_as=Field, _as=Field,
) )
options.fields = merge(
options.interface_fields,
options.django_fields,
options.base_fields,
options.local_fields
)
return cls if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces))
if use_connection and not connection:
# We create the connection automatically
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
class DjangoObjectType(six.with_metaclass(DjangoObjectTypeMeta, ObjectType)): if connection is not None:
assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}"
).format(connection.__name__)
def resolve_id(self, args, context, info): _meta = DjangoObjectTypeOptions(cls)
_meta.model = model
_meta.registry = registry
_meta.filter_fields = filter_fields
_meta.fields = django_fields
_meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)
if not skip_registry:
registry.register(cls)
def resolve_id(self, info):
return self.pk return self.pk
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
if isinstance(root, SimpleLazyObject): if isinstance(root, SimpleLazyObject):
root._setup() root._setup()
root = root._wrapped root = root._wrapped
@ -117,7 +108,7 @@ class DjangoObjectType(six.with_metaclass(DjangoObjectTypeMeta, ObjectType)):
return model == cls._meta.model return model == cls._meta.model
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
try: try:
return cls._meta.model.objects.get(pk=id) return cls._meta.model.objects.get(pk=id)
except cls._meta.model.DoesNotExist: except cls._meta.model.DoesNotExist:

View File

@ -13,3 +13,6 @@ omit = */tests/*
[isort] [isort]
known_first_party=graphene,graphene_django known_first_party=graphene,graphene_django
[bdist_wheel]
universal=1

View File

@ -1,7 +1,16 @@
from setuptools import find_packages, setup from setuptools import find_packages, setup
import sys
import ast
import re
_version_re = re.compile(r'__version__\s+=\s+(.*)')
with open('graphene_django/__init__.py', 'rb') as f:
version = str(ast.literal_eval(_version_re.search(
f.read().decode('utf-8')).group(1)))
rest_framework_require = [ rest_framework_require = [
'djangorestframework==3.6.3', 'djangorestframework>=3.6.3',
] ]
@ -17,7 +26,7 @@ tests_require = [
setup( setup(
name='graphene-django', name='graphene-django',
version='1.3', version=version,
description='Graphene Django integration', description='Graphene Django integration',
long_description=open('README.rst').read(), long_description=open('README.rst').read(),
@ -48,11 +57,11 @@ setup(
install_requires=[ install_requires=[
'six>=1.10.0', 'six>=1.10.0',
'graphene>=1.4', 'graphene>=2.0.dev',
'Django>=1.8.0', 'Django>=1.8.0',
'iso8601', 'iso8601',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
'promise>=2.0', 'promise>=2.1.dev',
], ],
setup_requires=[ setup_requires=[
'pytest-runner', 'pytest-runner',