Improved Django integration

This commit is contained in:
Syrus Akbary 2016-06-21 23:04:12 -07:00
parent 37ed617fce
commit f9303dab72
18 changed files with 108 additions and 72 deletions

View File

@ -1,5 +1,5 @@
from cookbook.ingredients.models import Category, Ingredient from cookbook.ingredients.models import Category, Ingredient
from graphene import ObjectType, relay from graphene import ObjectType, Field
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoNode, DjangoObjectType from graphene_django.types import DjangoNode, DjangoObjectType
@ -29,8 +29,8 @@ class IngredientNode(DjangoNode, DjangoObjectType):
class Query(ObjectType): class Query(ObjectType):
category = relay.NodeField(CategoryNode) category = Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)
ingredient = relay.NodeField(IngredientNode) ingredient = Field(IngredientNode)
all_ingredients = DjangoFilterConnectionField(IngredientNode) all_ingredients = DjangoFilterConnectionField(IngredientNode)

View File

@ -1,8 +1,9 @@
import cookbook.ingredients.schema
import graphene import graphene
import cookbook.ingredients.schema
# print cookbook.ingredients.schema.Query._meta.graphql_type.get_fields()['allIngredients'].args
class Query(cookbook.ingredients.schema.Query): class Query(cookbook.ingredients.schema.Query):
pass pass
schema = graphene.Schema(name='Cookbook Schema', query=Query) schema = graphene.Schema(query=Query)

View File

@ -3,7 +3,7 @@ from django.contrib import admin
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from cookbook.schema import schema from cookbook.schema import schema
from graphene.contrib.django.views import GraphQLView from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r'^admin/', admin.site.urls), url(r'^admin/', admin.site.urls),

View File

@ -4,13 +4,14 @@ from django.utils.encoding import force_text
from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.datetime import DateTime from graphene.types.datetime import DateTime
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_const from graphene.utils.str_converters import to_const
from graphene.relay import Node, ConnectionField from graphene.relay import Node
# from ...core.types.custom_scalars import DateTime, JSONString
from .compat import (ArrayField, HStoreField, JSONField, RangeField, from .compat import (ArrayField, HStoreField, JSONField, RangeField,
RelatedObject, UUIDField) RelatedObject, UUIDField)
from .utils import get_related_model, import_single_dispatch from .utils import get_related_model, import_single_dispatch
from .fields import DjangoConnectionField from .fields import get_connection_field
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@ -30,8 +31,7 @@ def convert_django_field_with_choices(field, registry=None):
meta = field.model._meta meta = field.model._meta
name = '{}{}'.format(meta.object_name, field.name.capitalize()) name = '{}{}'.format(meta.object_name, field.name.capitalize())
graphql_choices = list(convert_choices(choices)) graphql_choices = list(convert_choices(choices))
from collections import OrderedDict enum = Enum(name, list(graphql_choices))
enum = Enum(name, OrderedDict(graphql_choices))
return enum(description=field.help_text) return enum(description=field.help_text)
return convert_django_field(field, registry) return convert_django_field(field, registry)
@ -106,7 +106,7 @@ def convert_field_to_list_or_connection(field, registry=None):
return return
if issubclass(_type, Node): if issubclass(_type, Node):
return DjangoConnectionField(_type) return get_connection_field(_type)
return Field(List(_type)) return Field(List(_type))
@ -116,8 +116,8 @@ def convert_relatedfield_to_djangomodel(field, registry=None):
model = field.model model = field.model
_type = registry.get_type_for_model(model) _type = registry.get_type_for_model(model)
if issubclass(_type, Node): if issubclass(_type, Node):
return DjangoConnectionField(_type) return get_connection_field(_type)
return Field(List(_type)) return List(_type)
@convert_django_field.register(models.OneToOneField) @convert_django_field.register(models.OneToOneField)

View File

@ -1,4 +1,4 @@
from .....core import Boolean, Float, ObjectType, String from graphene import Boolean, Float, ObjectType, String
class DjangoDebugBaseSQL(ObjectType): class DjangoDebugBaseSQL(ObjectType):

View File

@ -1,8 +1,8 @@
import pytest import pytest
import graphene import graphene
from graphene.contrib.django import DjangoConnectionField, DjangoNode from graphene_django import DjangoConnectionField, DjangoNode, DjangoObjectType
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from ...tests.models import Reporter from ...tests.models import Reporter
from ..middleware import DjangoDebugMiddleware from ..middleware import DjangoDebugMiddleware
@ -23,7 +23,7 @@ def test_should_query_field():
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name='Griffin')
r2.save() r2.save()
class ReporterType(DjangoNode): class ReporterType(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -69,7 +69,7 @@ def test_should_query_list():
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name='Griffin')
r2.save() r2.save()
class ReporterType(DjangoNode): class ReporterType(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -117,7 +117,7 @@ def test_should_query_connection():
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name='Griffin')
r2.save() r2.save()
class ReporterType(DjangoNode): class ReporterType(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -166,14 +166,14 @@ def test_should_query_connection():
@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, @pytest.mark.skipif(not DJANGO_FILTER_INSTALLED,
reason="requires django-filter") reason="requires django-filter")
def test_should_query_connectionfilter(): def test_should_query_connectionfilter():
from graphene.contrib.django.filter import DjangoFilterConnectionField from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name='ABA')
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name='Griffin')
r2.save() r2.save()
class ReporterType(DjangoNode): class ReporterType(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter

View File

@ -1,7 +1,6 @@
from ....core.classtypes.objecttype import ObjectType from graphene import ObjectType, List
from ....core.types import Field
from .sql.types import DjangoDebugBaseSQL from .sql.types import DjangoDebugBaseSQL
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):
sql = Field(DjangoDebugBaseSQL.List()) sql = List(DjangoDebugBaseSQL)

View File

@ -1,7 +1,7 @@
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphene.relay import ConnectionField from graphene.relay import ConnectionField
from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .utils import maybe_queryset from .utils import maybe_queryset, DJANGO_FILTER_INSTALLED
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
@ -39,3 +39,10 @@ class DjangoConnectionField(ConnectionField):
connection_type=self.connection, connection_type=self.connection,
edge_type=self.connection.Edge, edge_type=self.connection.Edge,
) )
def get_connection_field(*args, **kwargs):
if DJANGO_FILTER_INSTALLED:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(*args, **kwargs)
return ConnectionField(*args, **kwargs)

View File

@ -1,5 +1,5 @@
import warnings import warnings
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED: if not DJANGO_FILTER_INSTALLED:
warnings.warn( warnings.warn(

View File

@ -1,6 +1,7 @@
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
from graphene.types.argument import to_arguments
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
@ -18,7 +19,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.filterset_class = get_filterset_class(filterset_class, **meta) self.filterset_class = get_filterset_class(filterset_class, **meta)
self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type) self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type)
kwargs.setdefault('args', {}) kwargs.setdefault('args', {})
kwargs['args'].update(**self.filtering_args) kwargs['args'].update(to_arguments(self.filtering_args))
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
def get_queryset(self, qs, args, info): def get_queryset(self, qs, args, info):

View File

@ -5,8 +5,7 @@ from django.utils.text import capfirst
from django_filters import Filter, MultipleChoiceFilter from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import FilterSet, FilterSetMetaclass from django_filters.filterset import FilterSet, FilterSetMetaclass
from graphene.contrib.django.forms import (GlobalIDFormField, from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
GlobalIDMultipleChoiceField)
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id

View File

@ -1,6 +1,6 @@
import django_filters import django_filters
from graphene.contrib.django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet): class ArticleFilter(django_filters.FilterSet):

View File

@ -2,51 +2,58 @@ from datetime import datetime
import pytest import pytest
from graphene import ObjectType, Schema from graphene import ObjectType, Schema, Field
from graphene.contrib.django import DjangoNode from graphene_django import DjangoNode, DjangoObjectType
from graphene.contrib.django.forms import (GlobalIDFormField, from graphene_django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField) GlobalIDMultipleChoiceField)
from graphene.contrib.django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene.relay import NodeField
pytestmark = [] pytestmark = []
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
import django_filters import django_filters
from graphene.contrib.django.filter import (GlobalIDFilter, DjangoFilterConnectionField, from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter) GlobalIDMultipleChoiceFilter)
from graphene.contrib.django.filter.tests.filters import ArticleFilter, PetFilter from graphene_django.filter.tests.filters import ArticleFilter, PetFilter
else: else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed')) pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed'))
pytestmark.append(pytest.mark.django_db) pytestmark.append(pytest.mark.django_db)
class ArticleNode(DjangoNode): class ArticleNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
class ReporterNode(DjangoNode): class ReporterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
class PetNode(DjangoNode): class PetNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
schema = Schema() # schema = Schema()
def get_args(field):
if isinstance(field.args, list):
return {arg.name: arg for arg in field.args}
else:
return field.args
def assert_arguments(field, *arguments): def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'orderBy') ignore = ('after', 'before', 'first', 'last', 'orderBy')
args = get_args(field)
actual = [ actual = [
name name
for name in schema.T(field.arguments) for name in args
if name not in ignore and not name.startswith('_') if name not in ignore and not name.startswith('_')
] ]
assert set(arguments) == set(actual), \ assert set(arguments) == set(actual), \
@ -57,12 +64,14 @@ def assert_arguments(field, *arguments):
def assert_orderable(field): def assert_orderable(field):
assert 'orderBy' in schema.T(field.arguments), \ args = get_args(field)
assert 'orderBy' in args, \
'Field cannot be ordered' 'Field cannot be ordered'
def assert_not_orderable(field): def assert_not_orderable(field):
assert 'orderBy' not in schema.T(field.arguments), \ args = get_args(field)
assert 'orderBy' not in args, \
'Field can be ordered' 'Field can be ordered'
@ -122,7 +131,7 @@ def test_filter_shortcut_filterset_extra_meta():
def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoNode): class ReporterFilterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -135,14 +144,14 @@ def test_filter_filterset_information_on_meta():
def test_filter_filterset_information_on_meta_related(): def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoNode): class ReporterFilterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
filter_fields = ['first_name', 'articles'] filter_fields = ['first_name', 'articles']
filter_order_by = True filter_order_by = True
class ArticleFilterNode(DjangoNode): class ArticleFilterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
@ -152,25 +161,24 @@ def test_filter_filterset_information_on_meta_related():
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
all_articles = DjangoFilterConnectionField(ArticleFilterNode) all_articles = DjangoFilterConnectionField(ArticleFilterNode)
reporter = NodeField(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = NodeField(ArticleFilterNode) article = Field(ArticleFilterNode)
schema = Schema(query=Query) schema = Schema(query=Query)
schema.schema # Trigger the schema loading articles_field = ReporterFilterNode._meta.graphql_type.get_fields()['articles']
articles_field = schema.get_type('ReporterFilterNode')._meta.fields_map['articles']
assert_arguments(articles_field, 'headline', 'reporter') assert_arguments(articles_field, 'headline', 'reporter')
assert_orderable(articles_field) assert_orderable(articles_field)
def test_filter_filterset_related_results(): def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoNode): class ReporterFilterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
filter_fields = ['first_name', 'articles'] filter_fields = ['first_name', 'articles']
filter_order_by = True filter_order_by = True
class ArticleFilterNode(DjangoNode): class ArticleFilterNode(DjangoNode, DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
@ -180,8 +188,8 @@ def test_filter_filterset_related_results():
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
all_articles = DjangoFilterConnectionField(ArticleFilterNode) all_articles = DjangoFilterConnectionField(ArticleFilterNode)
reporter = NodeField(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = NodeField(ArticleFilterNode) article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')

View File

@ -1,6 +1,6 @@
import six import six
from ....core.types import Argument, String from graphene import Argument, String
from .filterset import custom_filterset_factory, setup_filterset from .filterset import custom_filterset_factory, setup_filterset
@ -9,16 +9,16 @@ def get_filtering_args_from_filterset(filterset_class, type):
a Graphene Field. These arguments will be available to a Graphene Field. These arguments will be available to
filter against in the GraphQL filter against in the GraphQL
""" """
from graphene.contrib.django.form_converter import convert_form_field from ..form_converter import convert_form_field
args = {} args = {}
for name, filter_field in six.iteritems(filterset_class.base_filters): for name, filter_field in six.iteritems(filterset_class.base_filters):
field_type = Argument(convert_form_field(filter_field.field)) field_type = convert_form_field(filter_field.field)
args[name] = field_type args[name] = field_type
# Also add the 'order_by' field # Also add the 'order_by' field
if filterset_class._meta.order_by: if filterset_class._meta.order_by:
args[filterset_class.order_by_field] = Argument(String()) args[filterset_class.order_by_field] = String()
return args return args

View File

@ -248,4 +248,4 @@ def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List) field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.List) assert isinstance(field.type, graphene.List)
# assert isinstance(field.type.of_type, graphene.Int) assert field.type.of_type == get_graphql_type(graphene.Int)

View File

@ -8,7 +8,7 @@ from graphene.relay import Node
from graphene.relay.node import NodeMeta from graphene.relay.node import NodeMeta
from .converter import convert_django_field_with_choices from .converter import convert_django_field_with_choices
from graphene.types.options import Options from graphene.types.options import Options
from .utils import get_model_fields, is_valid_django_model from .utils import get_model_fields, is_valid_django_model, DJANGO_FILTER_INSTALLED
from .registry import Registry, get_global_registry from .registry import Registry, get_global_registry
from graphene.utils.is_base_type import is_base_type from graphene.utils.is_base_type import is_base_type
from graphene.utils.copy_fields import copy_fields from graphene.utils.copy_fields import copy_fields
@ -49,8 +49,7 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
if not is_base_type(bases, DjangoObjectTypeMeta): if not is_base_type(bases, DjangoObjectTypeMeta):
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
options = Options( defaults = dict(
attrs.pop('Meta', None),
name=None, name=None,
description=None, description=None,
model=None, model=None,
@ -59,6 +58,19 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
interfaces=(), interfaces=(),
registry=None registry=None
) )
if DJANGO_FILTER_INSTALLED:
# In case Django filter is available, then
# we allow more attributes in Meta
defaults = dict(
defaults,
filter_fields=(),
filter_order_by=(),
)
options = Options(
attrs.pop('Meta', None),
**defaults
)
if not options.registry: if not options.registry:
options.registry = get_global_registry() options.registry = get_global_registry()
assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry) assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry)
@ -77,7 +89,7 @@ class DjangoObjectTypeMeta(ObjectTypeMeta):
fields=partial(cls._construct_fields, fields, options), fields=partial(cls._construct_fields, fields, options),
interfaces=tuple(get_interfaces(interfaces + base_interfaces)) interfaces=tuple(get_interfaces(interfaces + base_interfaces))
) )
options.get_fields = lambda: {} options.get_fields = partial(cls._construct_fields, fields, options)
if issubclass(cls, DjangoObjectType): if issubclass(cls, DjangoObjectType):
options.registry.register(cls) options.registry.register(cls)

View File

@ -1,5 +1,6 @@
import re import re
from collections import Iterable import copy
from collections import Iterable, OrderedDict
import six import six
@ -72,8 +73,15 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
class IterableConnectionField(Field): class IterableConnectionField(Field):
def __init__(self, type, args={}, *other_args, **kwargs): def __init__(self, type, *other_args, **kwargs):
super(IterableConnectionField, self).__init__(type, args=connection_args, *other_args, **kwargs) args = kwargs.pop('args', {})
if not args:
args = connection_args
else:
args = copy.copy(args)
args.update(connection_args)
super(IterableConnectionField, self).__init__(type, args=args, *other_args, **kwargs)
@property @property
def type(self): def type(self):

View File

@ -6,6 +6,7 @@ from graphql.type.definition import GraphQLArgument, GraphQLArgumentDefinition
from graphql.utils.assert_valid_name import assert_valid_name from graphql.utils.assert_valid_name import assert_valid_name
from ..utils.orderedtype import OrderedType from ..utils.orderedtype import OrderedType
from ..utils.str_converters import to_camel_case
class Argument(GraphQLArgument, OrderedType): class Argument(GraphQLArgument, OrderedType):
@ -67,7 +68,7 @@ def to_arguments(*args, **extra):
raise ValueError('Unknown argument "{}".'.format(default_name)) raise ValueError('Unknown argument "{}".'.format(default_name))
arg = Argument.copy_from(arg) arg = Argument.copy_from(arg)
arg.name = arg.name or default_name arg.name = arg.name or default_name and to_camel_case(default_name)
assert arg.name, 'All arguments must have a name.' assert arg.name, 'All arguments must have a name.'
assert arg.name not in arguments_names, 'More than one Argument have same name "{}".'.format(arg.name) assert arg.name not in arguments_names, 'More than one Argument have same name "{}".'.format(arg.name)
arguments.append(arg) arguments.append(arg)