Merge branch 'master' into fix-deprecation-warnings

This commit is contained in:
Syrus Akbary 2017-08-29 22:33:10 -07:00 committed by GitHub
commit a4cc360184
23 changed files with 220 additions and 253 deletions

View File

@ -4,7 +4,7 @@ Filtering
Graphene integrates with
`django-filter <https://django-filter.readthedocs.org>`__ to provide
filtering of results. See the `usage
documentation <https://django-filter.readthedocs.org/en/latest/usage.html#the-filter>`__
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``.
@ -26,7 +26,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.org/en/latest/usage.html#the-filter>`__
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
for full details on the range of options available.
For example:

View File

@ -16,7 +16,7 @@ class Ship(DjangoObjectType):
interfaces = (relay.Node, )
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, info, id):
node = get_ship(id)
return node
@ -34,7 +34,7 @@ class Faction(DjangoObjectType):
interfaces = (relay.Node, )
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, info, id):
return get_faction(id)
@ -48,9 +48,7 @@ class IntroduceShip(relay.ClientIDMutation):
faction = graphene.Field(Faction)
@classmethod
def mutate_and_get_payload(cls, input, context, info):
ship_name = input.get('ship_name')
faction_id = input.get('faction_id')
def mutate_and_get_payload(cls, root, info, ship_name, faction_id, client_mutation_id=None):
ship = create_ship(ship_name, faction_id)
faction = get_faction(faction_id)
return IntroduceShip(ship=ship, faction=faction)

View File

@ -5,5 +5,10 @@ from .fields import (
DjangoConnectionField,
)
__all__ = ['DjangoObjectType',
'DjangoConnectionField']
__version__ = '2.0.dev2017073101'
__all__ = [
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]

View File

@ -2,15 +2,14 @@ from django.db import models
from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String)
from graphene.relay import is_node
NonNull, String, UUID)
from graphene.types.datetime import DateTime, Time
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
from .compat import ArrayField, HStoreField, JSONField, RangeField
from .fields import get_connection_field, DjangoListField
from .fields import DjangoListField, DjangoConnectionField
from .utils import import_single_dispatch
singledispatch = import_single_dispatch()
@ -79,11 +78,15 @@ def convert_field_to_string(field, registry=None):
@convert_django_field.register(models.AutoField)
@convert_django_field.register(models.UUIDField)
def convert_field_to_id(field, registry=None):
return ID(description=field.help_text, required=not field.null)
@convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None):
return UUID(description=field.help_text, required=not field.null)
@convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField)
@ -148,8 +151,16 @@ def convert_field_to_list_or_connection(field, registry=None):
if not _type:
return
if is_node(_type):
return get_connection_field(_type)
# If there is a connection, we should transform the field
# into a DjangoConnectionField
if _type._meta.connection:
# Use a DjangoFilterConnectionField if there are
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)
return DjangoListField(_type)

View File

@ -39,7 +39,8 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object):
def resolve(self, next, root, args, context, info):
def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None)
if not django_debug:
if context is None:
@ -52,6 +53,6 @@ class DjangoDebugMiddleware(object):
))
if info.schema.get_type('DjangoDebug') == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, args, context, info)
promise = next(root, info, **args)
context.django_debug.add_promise(promise)
return promise

View File

@ -33,7 +33,7 @@ def test_should_query_field():
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_reporter(self, *args, **kwargs):
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = '''
@ -80,7 +80,7 @@ def test_should_query_list():
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
@ -129,7 +129,7 @@ def test_should_query_connection():
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
@ -181,11 +181,11 @@ def test_should_query_connectionfilter():
interfaces = (Node, )
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType)
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''

View File

@ -9,7 +9,7 @@ from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .settings import graphene_settings
from .utils import DJANGO_FILTER_INSTALLED, maybe_queryset
from .utils import maybe_queryset
class DjangoListField(Field):
@ -22,8 +22,8 @@ class DjangoListField(Field):
return self.type.of_type._meta.node._meta.model
@staticmethod
def list_resolver(resolver, root, args, context, info):
return maybe_queryset(resolver(root, args, context, info))
def list_resolver(resolver, root, info, **args):
return maybe_queryset(resolver(root, info, **args))
def get_resolver(self, parent_resolver):
return partial(self.list_resolver, parent_resolver)
@ -43,6 +43,14 @@ class DjangoConnectionField(ConnectionField):
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
return _type._meta.connection
@property
def node_type(self):
return self.type._meta.node
@ -89,7 +97,7 @@ class DjangoConnectionField(ConnectionField):
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, args, context, info):
enforce_first_or_last, root, info, **args):
first = args.get('first')
last = args.get('last')
@ -111,7 +119,7 @@ class DjangoConnectionField(ConnectionField):
).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit)
iterable = resolver(root, args, context, info)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
if Promise.is_thenable(iterable):
@ -128,10 +136,3 @@ class DjangoConnectionField(ConnectionField):
self.max_limit,
self.enforce_first_or_last
)
def get_connection_field(*args, **kwargs):
if DJANGO_FILTER_INSTALLED:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(*args, **kwargs)
return DjangoConnectionField(*args, **kwargs)

View File

@ -1,7 +1,6 @@
from collections import OrderedDict
from functools import partial
# from graphene.relay import is_node
from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -69,7 +68,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args,
root, args, context, info):
root, info, **args):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
data=filter_kwargs,
@ -83,9 +82,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
max_limit,
enforce_first_or_last,
root,
args,
context,
info
info,
**args
)
def get_resolver(self, parent_resolver):

View File

@ -114,9 +114,9 @@ def test_filter_explicit_filterset_orderable():
assert_orderable(field)
def test_filter_shortcut_filterset_orderable_true():
field = DjangoFilterConnectionField(ReporterNode)
assert_not_orderable(field)
# def test_filter_shortcut_filterset_orderable_true():
# field = DjangoFilterConnectionField(ReporterNode)
# assert_not_orderable(field)
# def test_filter_shortcut_filterset_orderable_headline():
@ -356,7 +356,7 @@ def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)
def resolve_child_reporters(self, args, context, info):
def resolve_child_reporters(self, **args):
return []
class Meta:
@ -399,7 +399,7 @@ def test_should_query_filter_node_limit():
filterset_class=ReporterFilter
)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')
Reporter.objects.create(
@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises():
filterset_class=ReporterFilter
)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2]
Reporter.objects.create(

View File

@ -1,7 +1,7 @@
from django import forms
from django.forms.fields import BaseTemporalField
from graphene import ID, Boolean, Float, Int, List, String
from graphene import ID, Boolean, Float, Int, List, String, UUID
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from .utils import import_single_dispatch
@ -32,11 +32,15 @@ def convert_form_field(field):
@convert_form_field.register(forms.ChoiceField)
@convert_form_field.register(forms.RegexField)
@convert_form_field.register(forms.Field)
@convert_form_field.register(UUIDField)
def convert_form_field_to_string(field):
return String(description=field.help_text, required=field.required)
@convert_form_field.register(UUIDField)
def convert_form_field_to_uuid(field):
return UUID(description=field.help_text, required=field.required)
@convert_form_field.register(forms.IntegerField)
@convert_form_field.register(forms.NumberInput)
def convert_form_field_to_int(field):

View File

@ -1,3 +1,4 @@
class Registry(object):
def __init__(self):

View File

@ -1,116 +1,74 @@
from collections import OrderedDict
from functools import partial
import six
import graphene
from graphene.types import Argument, Field
from graphene.types.mutation import Mutation, MutationMeta
from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import (
ObjectTypeMeta,
merge,
yank_fields_from_attrs
)
from graphene.types.options import Options
from graphene.types.utils import get_field_as
from graphene.utils.is_base_type import is_base_type
from .serializer_converter import (
convert_serializer_to_input_type,
convert_serializer_field
)
from .types import ErrorType
class SerializerMutationOptions(Options):
def __init__(self, *args, **kwargs):
super().__init__(*args, serializer_class=None, **kwargs)
class SerializerMutationOptions(MutationOptions):
serializer_class = None
class SerializerMutationMeta(MutationMeta):
def __new__(cls, name, bases, attrs):
if not is_base_type(bases, SerializerMutationMeta):
return type.__new__(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=name,
description=attrs.pop('__doc__', None),
serializer_class=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
registry=None
def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False):
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in exclude_fields # or
# name in already_created_fields
)
if not options.serializer_class:
raise Exception('Missing serializer_class')
if is_not_in_only or is_excluded:
continue
cls = ObjectTypeMeta.__new__(
cls, name, bases, dict(attrs, _meta=options)
)
serializer_fields = cls.fields_for_serializer(options)
options.serializer_fields = yank_fields_from_attrs(
serializer_fields,
_as=Field,
)
options.fields = merge(
options.interface_fields, options.serializer_fields,
options.base_fields, options.local_fields,
{'errors': get_field_as(cls.errors, Field)}
)
cls.Input = convert_serializer_to_input_type(options.serializer_class)
cls.Field = partial(
Field,
cls,
resolver=cls.mutate,
input=Argument(cls.Input, required=True)
)
return cls
@staticmethod
def fields_for_serializer(options):
serializer = options.serializer_class()
only_fields = options.only_fields
already_created_fields = {
name
for name, _ in options.local_fields.items()
}
fields = OrderedDict()
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in options.exclude_fields or
name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_serializer_field(field, is_input=False)
return fields
fields[name] = convert_serializer_field(field, is_input=is_input)
return fields
class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)):
class SerializerMutation(ClientIDMutation):
class Meta:
abstract = True
errors = graphene.List(
ErrorType,
description='May contain more than one error for '
'same field.'
description='May contain more than one error for same field.'
)
@classmethod
def mutate(cls, instance, args, request, info):
input = args.get('input')
def __init_subclass_with_meta__(cls, serializer_class=None,
only_fields=(), exclude_fields=(), **options):
serializer = cls._meta.serializer_class(data=dict(input))
if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
serializer = serializer_class()
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
_meta = SerializerMutationOptions(cls)
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
)
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
serializer = cls._meta.serializer_class(data=input)
if serializer.is_valid():
return cls.perform_mutate(serializer, info)
@ -125,5 +83,4 @@ class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)):
@classmethod
def perform_mutate(cls, serializer, info):
obj = serializer.save()
return cls(errors=[], **obj)
return cls(**obj)

View File

@ -2,6 +2,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers
import graphene
from graphene import Dynamic
from ..registry import get_global_registry
from ..utils import import_single_dispatch
@ -10,21 +11,6 @@ from .types import DictType
singledispatch = import_single_dispatch()
def convert_serializer_to_input_type(serializer_class):
serializer = serializer_class()
items = {
name: convert_serializer_field(field)
for name, field in serializer.fields.items()
}
return type(
'{}Input'.format(serializer.__class__.__name__),
(graphene.InputObjectType, ),
items
)
@singledispatch
def get_graphene_type_from_serializer_field(field):
raise ImproperlyConfigured(
@ -56,7 +42,8 @@ def convert_serializer_field(field, is_input=True):
if isinstance(field, serializers.ModelSerializer):
if is_input:
graphql_type = convert_serializer_to_input_type(field.__class__)
return Dynamic(lambda: None)
# graphql_type = convert_serializer_to_input_type(field.__class__)
else:
global_registry = get_global_registry()
field_model = field.Meta.model

View File

@ -28,7 +28,7 @@ def test_needs_serializer_class():
class MyMutation(SerializerMutation):
pass
assert exc.value.args[0] == 'Missing serializer_class'
assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
def test_has_fields():
@ -65,6 +65,7 @@ def test_nested_model():
assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model']
model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields
model_input_type = model_input.get_type()
assert not model_input_type
# assert issubclass(model_input_type, InputObjectType)
# assert 'cool_name' in model_input_type._meta.fields

View File

@ -11,7 +11,7 @@ class Character(DjangoObjectType):
model = Reporter
interfaces = (relay.Node, )
def get_node(self, id, context, info):
def get_node(self, info, id):
pass
@ -22,17 +22,17 @@ class Human(DjangoObjectType):
model = Article
interfaces = (relay.Node, )
def resolve_raises(self, *args):
def resolve_raises(self, info):
raise Exception("This field should raise exception")
def get_node(self, id):
def get_node(self, info, id):
pass
class Query(graphene.ObjectType):
human = graphene.Field(Human)
def resolve_human(self, args, context, info):
def resolve_human(self, info):
return Human()

View File

@ -8,21 +8,20 @@ class QueryRoot(ObjectType):
request = graphene.String(required=True)
test = graphene.String(who=graphene.String())
def resolve_thrower(self, args, context, info):
def resolve_thrower(self, info):
raise Exception("Throws!")
def resolve_request(self, args, context, info):
request = context
return request.GET.get('q')
def resolve_request(self, info):
return info.context.GET.get('q')
def resolve_test(self, args, context, info):
return 'Hello %s' % (args.get('who') or 'World')
def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World')
class MutationRoot(ObjectType):
write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, args, context, info):
def resolve_write_test(self, info):
return QueryRoot()

View File

@ -84,7 +84,7 @@ def test_should_auto_convert_id():
def test_should_auto_convert_id():
assert_conversion(models.UUIDField, graphene.ID)
assert_conversion(models.UUIDField, graphene.UUID)
def test_should_auto_convert_duration():
@ -224,7 +224,7 @@ def test_should_manytomany_convert_connectionorlist_connection():
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField)
assert dynamic_field.type == A.Connection
assert dynamic_field.type == A._meta.connection
def test_should_manytoone_convert_connectionorlist():

View File

@ -65,7 +65,7 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'):
assert_conversion(forms.UUIDField, graphene.String)
assert_conversion(forms.UUIDField, graphene.UUID)
def test_should_integer_convert_int():

View File

@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects():
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(self, args, context, info):
def resolve_reporter(self, info):
return SimpleLazyObject(lambda: Reporter(id=1))
schema = graphene.Schema(query=Query)
@ -75,7 +75,7 @@ def test_should_query_well():
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(self, *args, **kwargs):
def resolve_reporter(self, info):
return Reporter(first_name='ABA', last_name='X')
query = '''
@ -119,7 +119,7 @@ def test_should_query_postgres_fields():
class Query(graphene.ObjectType):
event = graphene.Field(EventType)
def resolve_event(self, *args, **kwargs):
def resolve_event(self, info):
return Event(
ages=(0, 10),
data={'angry_babies': True},
@ -162,10 +162,10 @@ def test_should_node():
interfaces = (Node, )
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, info, id):
return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, *args, **kwargs):
def resolve_articles(self, info, **args):
return [Article(headline='Hi!')]
class ArticleNode(DjangoObjectType):
@ -175,7 +175,7 @@ def test_should_node():
interfaces = (Node, )
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, info, id):
return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11))
class Query(graphene.ObjectType):
@ -183,7 +183,7 @@ def test_should_node():
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
def resolve_reporter(self, *args, **kwargs):
def resolve_reporter(self, info):
return Reporter(id=1, first_name='ABA', last_name='X')
query = '''
@ -250,7 +250,7 @@ def test_should_query_connectionfields():
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, info, **args):
return [Reporter(id=1)]
schema = graphene.Schema(query=Query)
@ -308,10 +308,10 @@ def test_should_keep_annotations():
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoConnectionField(ArticleType)
def resolve_all_reporters(self, args, context, info):
def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c')
def resolve_all_articles(self, args, context, info):
def resolve_all_articles(self, info, **args):
return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg')
schema = graphene.Schema(query=Query)
@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields():
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, *args, **kwargs):
def resolve_all_reporters(self, info, **args):
return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query)
@ -673,10 +673,11 @@ def test_should_query_dataloader_fields():
class Meta:
model = Reporter
interfaces = (Node, )
use_connection = True
articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, *args, **kwargs):
def resolve_articles(self, info, **args):
return article_loader.load(self.id)
class Query(graphene.ObjectType):

View File

@ -38,7 +38,7 @@ def test_django_interface():
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
def test_django_get_node(get):
article = Article.get_node(1, None, None)
article = Article.get_node(None, 1)
get.assert_called_with(pk=1)
assert article.id == 1

View File

@ -1,13 +1,10 @@
from collections import OrderedDict
import six
from django.utils.functional import SimpleLazyObject
from graphene import Field, ObjectType
from graphene.types.objecttype import ObjectTypeMeta
from graphene.types.options import Options
from graphene.types.utils import merge, yank_fields_from_attrs
from graphene.utils.is_base_type import is_base_type
from graphene import Field
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry
@ -15,16 +12,14 @@ from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
is_valid_django_model)
def construct_fields(options):
_model_fields = get_model_fields(options.model)
only_fields = options.only_fields
exclude_fields = options.exclude_fields
def construct_fields(model, registry, only_fields, exclude_fields):
_model_fields = get_model_fields(model)
fields = OrderedDict()
for name, field in _model_fields:
is_not_in_only = only_fields and name not in options.only_fields
is_already_created = name in options.fields
is_excluded = name in exclude_fields or is_already_created
is_not_in_only = only_fields and name not in only_fields
# is_already_created = name in options.fields
is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+')
if is_not_in_only or is_excluded or is_no_backref:
@ -32,78 +27,74 @@ def construct_fields(options):
# in there. Or when we exclude this field in exclude_fields.
# Or when there is no back reference.
continue
converted = convert_django_field_with_choices(field, options.registry)
converted = convert_django_field_with_choices(field, registry)
fields[name] = converted
return fields
class DjangoObjectTypeMeta(ObjectTypeMeta):
class DjangoObjectTypeOptions(ObjectTypeOptions):
model = None # type: Model
registry = None # type: Registry
connection = None # type: Type[Connection]
@staticmethod
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of
# DjangoObjectType
if not is_base_type(bases, DjangoObjectTypeMeta):
return type.__new__(cls, name, bases, attrs)
filter_fields = ()
defaults = dict(
name=name,
description=attrs.pop('__doc__', None),
model=None,
local_fields=None,
only_fields=(),
exclude_fields=(),
interfaces=(),
skip_registry=False,
registry=None
)
if DJANGO_FILTER_INSTALLED:
# In case Django filter is available, then
# we allow more attributes in Meta
defaults.update(
filter_fields=(),
)
options = Options(
attrs.pop('Meta', None),
**defaults
)
if not options.registry:
options.registry = get_global_registry()
assert isinstance(options.registry, Registry), (
'The attribute registry in {}.Meta needs to be an instance of '
'Registry, received "{}".'
).format(name, options.registry)
assert is_valid_django_model(options.model), (
class DjangoObjectType(ObjectType):
@classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
use_connection=None, interfaces=(), **options):
assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(name, options.model)
).format(cls.__name__, model)
cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options))
if not registry:
registry = get_global_registry()
options.registry.register(cls)
assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of '
'Registry, received "{}".'
).format(cls.__name__, registry)
options.django_fields = yank_fields_from_attrs(
construct_fields(options),
if not DJANGO_FILTER_INSTALLED and filter_fields:
raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields),
_as=Field,
)
options.fields = merge(
options.interface_fields,
options.django_fields,
options.base_fields,
options.local_fields
)
return cls
if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces))
if use_connection and not connection:
# We create the connection automatically
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
class DjangoObjectType(six.with_metaclass(DjangoObjectTypeMeta, ObjectType)):
if connection is not None:
assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}"
).format(connection.__name__)
def resolve_id(self, args, context, info):
_meta = DjangoObjectTypeOptions(cls)
_meta.model = model
_meta.registry = registry
_meta.filter_fields = filter_fields
_meta.fields = django_fields
_meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)
if not skip_registry:
registry.register(cls)
def resolve_id(self, info):
return self.pk
@classmethod
def is_type_of(cls, root, context, info):
def is_type_of(cls, root, info):
if isinstance(root, SimpleLazyObject):
root._setup()
root = root._wrapped
@ -117,7 +108,7 @@ class DjangoObjectType(six.with_metaclass(DjangoObjectTypeMeta, ObjectType)):
return model == cls._meta.model
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, info, id):
try:
return cls._meta.model.objects.get(pk=id)
except cls._meta.model.DoesNotExist:

View File

@ -13,3 +13,6 @@ omit = */tests/*
[isort]
known_first_party=graphene,graphene_django
[bdist_wheel]
universal=1

View File

@ -1,7 +1,16 @@
from setuptools import find_packages, setup
import sys
import ast
import re
_version_re = re.compile(r'__version__\s+=\s+(.*)')
with open('graphene_django/__init__.py', 'rb') as f:
version = str(ast.literal_eval(_version_re.search(
f.read().decode('utf-8')).group(1)))
rest_framework_require = [
'djangorestframework==3.6.3',
'djangorestframework>=3.6.3',
]
@ -17,7 +26,7 @@ tests_require = [
setup(
name='graphene-django',
version='1.3',
version=version,
description='Graphene Django integration',
long_description=open('README.rst').read(),
@ -48,11 +57,11 @@ setup(
install_requires=[
'six>=1.10.0',
'graphene>=1.4',
'graphene>=2.0.dev',
'Django>=1.8.0',
'iso8601',
'singledispatch>=3.4.0.3',
'promise>=2.0',
'promise>=2.1.dev',
],
setup_requires=[
'pytest-runner',